diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn_padded.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn_padded.hpp index 8adc526afd..ad53cd89c3 100644 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn_padded.hpp +++ b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn_padded.hpp @@ -47,6 +47,19 @@ template struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_padded { + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + static constexpr auto I8 = Number<8>{}; + static constexpr auto I9 = Number<9>{}; + static constexpr auto I10 = Number<10>{}; + static constexpr auto I11 = Number<11>{}; + #if 0 __device__ void Run(const Float* const __restrict__ p_in_global, const Float* const __restrict__ p_wei_global, @@ -60,11 +73,6 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_padded GemmNPerThreadSubC % NPerThread == 0)), "wrong!"); - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - constexpr auto True = integral_constant{}; constexpr auto False = integral_constant{}; @@ -487,58 +495,67 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_padded Float* const __restrict__ p_out_global) const { #if 0 - constexpr auto tmp = std::tuple{}; - constexpr auto flag = std::get<0>(tmp); -#else - constexpr auto a = Tuple, index_t>(true, Sequence<1>{}, 99); + constexpr auto a = make_tuple(true, Sequence<1>{}, index_t(99)); if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) { - printf("adsas %d\n", a.At(Number<0>{})); - print_Sequence("seq", a.At(Number<1>{})); - printf("adsas %lu\n", a.At(Number<2>{})); + printf("[0] %d\n", a.At(I0)); + print_Sequence("[1]", a.At(I1)); + printf("[2] %lu\n", a.At(I2)); } - auto b = Tuple, index_t>(true, Sequence<1>{}, 99); + bool flag = true; - b.At(Number<0>{}) = false; + auto b = make_tuple(flag, Sequence<1>{}, 99); + + b.At(I0) = false; if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) { - printf("adsas %d\n", b.At(Number<0>{})); - print_Sequence("seq", b.At(Number<1>{})); - printf("adsas %lu\n", b.At(Number<2>{})); + printf("[0] %d\n", b.At(I0)); + print_Sequence("[1]", b.At(I1)); + printf("[2] %lu\n", b.At(I2)); + + printf("flag %d\n", flag); } if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) { - printf("adsas %d\n", - Tuple, index_t>(true, Sequence<1>(), 99).At(Number<0>{})); - print_Sequence( - "seq", Tuple, index_t>(true, Sequence<1>(), 99).At(Number<1>{})); - printf("adsas %d\n", - Tuple, index_t>(true, Sequence<1>(), 99).At(Number<2>{})); + printf("[0] %d\n", make_tuple(true, Sequence<1>(), index_t(99)).At(I0)); + print_Sequence("[1]", make_tuple(true, Sequence<1>(), index_t(99)).At(I1)); + printf("[2] %d\n", make_tuple(true, Sequence<1>(), index_t(99)).At(I2)); } -#endif - -#if 0 - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - +#elif 1 // create a native tensor descriptor - constexpr auto in_n_c_h_w_global_desc = + constexpr auto in_c_h_w_n_global_desc = make_NativeTensorDescriptor(InGlobalDesc::GetLengths(), InGlobalDesc::GetStrides()); + constexpr index_t C = in_c_h_w_n_global_desc.GetLength(I0); + constexpr index_t Hi = in_c_h_w_n_global_desc.GetLength(I1); + constexpr index_t Wi = in_c_h_w_n_global_desc.GetLength(I2); + constexpr index_t N = in_c_h_w_n_global_desc.GetLength(I3); + + constexpr auto pad_h_w = Pad, LowerPads, UpperPads>{}; + constexpr auto pass_c = PassThrough{}; + constexpr auto pass_n = PassThrough{}; + + constexpr auto trans = make_tuple(pass_c, pad_h_w, pass_n); + constexpr auto lower_dim_groups = + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}); + constexpr auto upper_dim_groups = + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}); + + constexpr auto in_c_h_w_n_padded_global_desc = transform_tensor_descriptor( + in_c_h_w_n_global_desc, trans, lower_dim_groups, upper_dim_groups); + if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) { - print_tensor_descriptor("in_n_c_h_w_global_desc", in_n_c_h_w_global_desc); - } + print_tensor_descriptor("in_c_h_w_n_global_desc", in_c_h_w_n_global_desc); - // transform the tensor descriptor once - // - // calculate the offset of some entry + printf("offset: %lu\n", in_c_h_w_n_global_desc.GetOffset({1, 2, 3, 4})); + + printf("padded offset: %lu\n", in_c_h_w_n_padded_global_desc.GetOffset({1, 4, 5, 4})); + } #endif } #endif diff --git a/composable_kernel/include/tensor_description/ConstantTensorDescriptor.hpp b/composable_kernel/include/tensor_description/ConstantTensorDescriptor.hpp index 3949652174..b0f6d32ec8 100644 --- a/composable_kernel/include/tensor_description/ConstantTensorDescriptor.hpp +++ b/composable_kernel/include/tensor_description/ConstantTensorDescriptor.hpp @@ -178,7 +178,7 @@ struct ConstantTensorDescriptor { constexpr auto IDim = IDim_{}; constexpr index_t stride = PackedStrides::Get(IDim); - multi_id.Set(IDim, id / stride); + multi_id(IDim) = id / stride; id -= multi_id[IDim] * stride; } }; @@ -192,7 +192,7 @@ struct ConstantTensorDescriptor // calculate index in each of the dimensions in the order of their dimension static_for<0, nDim - 1, 1>{}(lambda_GetMultiIndexFrom1dIndex(id, multi_id)); - multi_id.Set(Number{}, id / PackedStrides::Get(Number{})); + multi_id(Number{}) = id / PackedStrides::Get(Number{}); return multi_id; } diff --git a/composable_kernel/include/tensor_description/multi_index_transform.hpp b/composable_kernel/include/tensor_description/multi_index_transform.hpp index 8812779319..c21f055390 100644 --- a/composable_kernel/include/tensor_description/multi_index_transform.hpp +++ b/composable_kernel/include/tensor_description/multi_index_transform.hpp @@ -33,7 +33,7 @@ struct PassThrough }; // LowLengths: Sequence<...> -template +template struct Pad { static constexpr index_t nDim = LowLengths::GetSize(); @@ -67,7 +67,7 @@ struct Pad #if 0 // LowLengths: Sequence<...> -template +template struct Merge { static constexpr index_t nDimLow = LowLengths::GetSize(); @@ -113,7 +113,7 @@ struct Merge #endif // UpLengths: Sequence<...> -template +template struct Unmerge { static constexpr index_t nDimLow = 1; @@ -161,7 +161,7 @@ struct Unmerge // UpLengths: Sequence<...> // Coefficients: Sequence<...> // idx_low = coefficients[0, ...nDimUp-1] * idx_up[0, ...nDimUp-1] + coefficients[nDimUp] -template +template struct Embed { static constexpr index_t nDimLow = 1; diff --git a/composable_kernel/include/tensor_description/tensor_descriptor.hpp b/composable_kernel/include/tensor_description/tensor_descriptor.hpp index 4610fc5f74..952b378151 100644 --- a/composable_kernel/include/tensor_description/tensor_descriptor.hpp +++ b/composable_kernel/include/tensor_description/tensor_descriptor.hpp @@ -7,12 +7,12 @@ namespace ck { -template +template struct NativeTensorDescriptor { using type = NativeTensorDescriptor; - static constexpr auto mDimensions = Tuple{}; - static constexpr index_t nDim = mDimensions.GetSize(); + static constexpr index_t nDim = sizeof...(NativeDimensions); + static constexpr auto mDimensions = make_tuple(NativeDimensions{}...); using Index = MultiIndex; @@ -20,7 +20,7 @@ struct NativeTensorDescriptor struct lambda_GetLength { - template + template __host__ __device__ constexpr auto operator()(IDim) const { return GetLength(IDim{}); @@ -34,7 +34,7 @@ struct NativeTensorDescriptor struct lambda_GetStride { - template + template __host__ __device__ constexpr auto operator()(IDim) const { return GetStride(IDim{}); @@ -49,16 +49,16 @@ struct NativeTensorDescriptor template __host__ __device__ static constexpr auto GetLength(Number) { - return mDimensions.Get(Number{}).GetLength(); + return mDimensions.At(Number{}).GetLength(); } template __host__ __device__ static constexpr auto GetStride(Number) { - return mDimensions.Get(Number{}).GetStride(); + return mDimensions.At(Number{}).GetStride(); } - __host__ __device__ static constexpr index_t GetOffset(Index idx) + __host__ __device__ static constexpr index_t GetOffset(const Index& idx) { index_t offset = 0; @@ -67,7 +67,7 @@ struct NativeTensorDescriptor return offset; } - __host__ __device__ static constexpr index_t GetOffsetDiff(Index idx_diff) + __host__ __device__ static constexpr index_t GetOffsetDiff(const Index& idx_diff) { index_t offset_diff = 0; @@ -96,28 +96,65 @@ struct NativeTensorDescriptor } }; -#if 0 // LowerTensorDescriptor -// Transforms: std::tuple -// LowerDimensionIds: std::tuple> -// UpperDimensionIds: std::tuple> -template +// Transforms: Tuple +// LowerDimensionIds: Tuple> +// UpperDimensionIds: Tuple> +template struct TransformedTensorDescriptor { - using type = TransformedTensorDescriptor; - static constexpr index_t nDimUp = GetUpperNumOfDimension(); - static constexpr index_t nDimLow = GetLowerNumOfDimension(); + using type = TransformedTensorDescriptor; + static constexpr index_t nTransform = Transforms::Size(); - static constexpr index_t nTransform = Transforms::GetSize(); + struct lambda_merge_sequences + { + template + __host__ __device__ constexpr auto operator()(Seqs... seqs) const + { + return merge_sequences(seqs...); + } + }; + + __host__ __device__ static constexpr auto GetNumOfLowerDimension() + { + // Here, we assume all lower-dimensions are active + // TODO: sanity-check all lower-dimension are indeed active + + using duplicated_low_active_dims = + decltype(unpack(lambda_merge_sequences{}, LowDimensionIds{})); + + using low_active_dims = typename sequence_unique_sort, + math::equal>::type; + + return low_active_dims::Size(); + } + + __host__ __device__ static constexpr auto GetNumOfUpperDimension() + { + using duplicated_up_active_dims = + decltype(unpack(lambda_merge_sequences{}, UpDimensionIds{})); + + using up_active_dims = typename sequence_unique_sort, + math::equal>::type; + + return up_active_dims::Size(); + } + + static constexpr index_t nDimUp = GetNumOfUpperDimension(); + static constexpr index_t nDimLow = GetNumOfLowerDimension(); using UpperIndex = MultiIndex; using LowerIndex = MultiIndex; - __host__ __device__ static constexpr TransformedTensorDescriptor() + __host__ __device__ constexpr TransformedTensorDescriptor() { - static_assert(nTransform == Transforms::GetSize() && - nTransform == LowDimensionIds::GetSize() && - nTransform == UpDimensionIds::GetSize(), + static_assert(nTransform == Transforms::Size() && nTransform == LowDimensionIds::Size() && + nTransform == UpDimensionIds::Size(), "wrong! # of transformations not the same"); // TODO: sanity check: LowDimensionIds should include all low-dimensions, @@ -128,33 +165,17 @@ struct TransformedTensorDescriptor // a low-dimension should be associated with only one transformation } - __host__ __device__ static constexpr auto GetNumOfLowerDimension() - { - // Here, we assume all lower-dimensions are active - // TODO: sanity-check all lower-dimension are indeed active - constexpr auto low_active_dims = unique_sort_sequence( - merge_tuple_of_sequences(LowDimensionIds{}), math::less{}); - - return low_active_dims.GetSize(); - } - - __host__ __device__ static constexpr auto GetNumOfUpperDimension() - { - constexpr auto up_active_dims = - unique_sort_sequence(merge_tuple_of_sequences(UpDimensionIds{}), math::less{}); - return up_active_dims.GetSize(); - } - __host__ __device__ static constexpr auto GetNumOfDimension() { return GetNumOfUpperDimension(); } - __host__ __device__ static constexpr auto GetLengths() +#if 0 + __host__ __device__ static constexpr auto GetUpperLengths() { struct lambda_get_upper_lengths { - template + template __host__ __device__ constexpr auto operator()(Transform tran) const { return tran.GetUpperLengths(); @@ -173,6 +194,7 @@ struct TransformedTensorDescriptor using sort_dimension_ids = sequence_unique_sort>; + constexpr auto sorted_upper_dimension_ids = typename sort_dimension_ids::type; constexpr auto sorted2unsorted_map = typename sort_dimension_ids::sorted2unsorted_map_type; @@ -182,46 +204,48 @@ struct TransformedTensorDescriptor return sorted_upper_lengths; } + __host__ __device__ static constexpr auto GetLengths() { return GetUpperLengths(); } +#endif + __host__ __device__ static constexpr auto GetLowerTensorDescriptor() { return LowTensorDescriptor{}; } - __host__ __device__ static constexpr index_t GetLowerIndex(UpperIndex idx_up) + __host__ __device__ static constexpr LowerIndex GetLowerIndex(const UpperIndex& idx_up) { LowerIndex idx_low; static_for<0, nTransform, 1>{}([&](auto itran) { - constexpr auto tran = Transforms::Get(itran); + constexpr auto tran = Transforms{}.At(itran); - constexpr auto idx_low_part = pick_array_element(idx_low, LowDimensionIds::Get(itran)); - constexpr auto idx_up_part = pick_array_element(idx_up, UpDimensionIds::Get(itran)); + auto idx_low_part = pick_array_element(idx_low, LowDimensionIds{}.At(itran)); + const auto idx_up_part = pick_array_element(idx_up, UpDimensionIds{}.At(itran)); // this assume each lower (single) index is only assocaited with one transformation, // which is required for index transformation, and has been checked during constructor // of TransformedTensorDescriptor - idx_low_part = tran.GetLowerIndex(idx_up_part); + idx_low_part = tran.GetLowerIndex(to_array(idx_up_part)); }); return idx_low; } - __host__ __device__ static constexpr index_t GetLowerIndexDiff(UpperIndex idx_up_diff, - LowerIndex idx_low_old) + __host__ __device__ static constexpr LowerIndex GetLowerIndexDiff(const UpperIndex& idx_up_diff, + const LowerIndex& idx_low_old) { LowerIndex idx_low_diff; static_for<0, nTransform, 1>{}([&](auto itran) { - constexpr auto tran = Transforms::Get(itran); + constexpr auto tran = Transforms::At(itran); - constexpr auto idx_up_diff_part = - pick_array_element(idx_up_diff, UpDimensionIds::Get(itran)); + const auto idx_up_diff_part = + pick_array_element(idx_up_diff, UpDimensionIds::At(itran)); - constexpr auto idx_low_diff_part = - pick_array_element(idx_low_diff, LowDimensionIds::Get(itran)); + auto idx_low_diff_part = pick_array_element(idx_low_diff, LowDimensionIds::At(itran)); - constexpr auto idx_low_old_part = - pick_array_element(idx_low_old, LowDimensionIds::Get(itran)); + const auto idx_low_old_part = + pick_array_element(idx_low_old, LowDimensionIds::At(itran)); // this assume each lower (single) index is associated with only one transformation, // which is required for index transformation, and has been checked during constructor @@ -232,13 +256,14 @@ struct TransformedTensorDescriptor return idx_low_diff; } - __host__ __device__ static constexpr index_t GetOffset(UpperIndex idx_up) + __host__ __device__ static constexpr index_t GetOffset(const UpperIndex& idx_up) { return GetLowerTensorDescriptor().GetOffset(GetLowerIndex(idx_up)); } +#if 0 template - __host__ __device__ static constexpr bool IsLinearDimension(Number); + __host__ __device__ static constexpr bool IsLinearDimension(Number) { // not implemented } @@ -257,8 +282,8 @@ struct TransformedTensorDescriptor { // not implemented } -}; #endif +}; template __host__ __device__ constexpr auto make_NativeTensorDescriptor(Sequence, @@ -267,15 +292,28 @@ __host__ __device__ constexpr auto make_NativeTensorDescriptor(Sequence...>{}; } -template +template __host__ __device__ constexpr auto make_NativeTensorDescriptor_packed(Lengths) { - constexpr index_t strides = reverse_inclusive_scan_sequence( - Lengths::PopFront(), math::multiplies{}, Number<1>{}) - .PushBack(Number<1>{}); + constexpr auto strides = reverse_inclusive_scan_sequence( + Lengths::PopFront(), math::multiplies{}, Number<1>{}) + .PushBack(Number<1>{}); return make_NativeTensorDescriptor(Lengths{}, strides); } +template +__host__ __device__ constexpr auto + transform_tensor_descriptor(LowTensorDescriptor, Transforms, LowDimensionIds, UpDimensionIds) +{ + return TransformedTensorDescriptor{}; +} + } // namespace ck #endif diff --git a/composable_kernel/include/tensor_description/tensor_descriptor_helper.hpp b/composable_kernel/include/tensor_description/tensor_descriptor_helper.hpp index f1f662f772..04009a740c 100644 --- a/composable_kernel/include/tensor_description/tensor_descriptor_helper.hpp +++ b/composable_kernel/include/tensor_description/tensor_descriptor_helper.hpp @@ -6,7 +6,7 @@ namespace ck { -template +template __host__ __device__ void print_tensor_descriptor(const char* s, NativeTensorDescriptor desc) { diff --git a/composable_kernel/include/utility/Array.hpp b/composable_kernel/include/utility/Array.hpp index 8b94d46b83..1cc8d4d0d6 100644 --- a/composable_kernel/include/utility/Array.hpp +++ b/composable_kernel/include/utility/Array.hpp @@ -6,48 +6,78 @@ namespace ck { -template +template struct Array { - using Type = Array; + using type = Array; using data_type = TData; - static constexpr index_t nSize = NSize; + index_t mData[NSize]; - index_t mData[nSize]; + __host__ __device__ explicit constexpr Array() {} - template - __host__ __device__ constexpr Array(Xs... xs) : mData{static_cast(xs)...} + template + __host__ __device__ explicit constexpr Array(X x, Xs... xs) + : mData{static_cast(x), static_cast(xs)...} { + static_assert(sizeof...(Xs) + 1 == NSize, "wrong! size"); } - __host__ __device__ static constexpr index_t GetSize() { return NSize; } - - template - __host__ __device__ constexpr TData operator[](Number) const +#if 0 + template + __host__ __device__ explicit constexpr Array(const T& x) { - return mData[I]; - } + static_assert(T::Size() == NSize, "wrong! size"); - __host__ __device__ constexpr TData operator[](index_t i) const { return mData[i]; } + static_for<0, NSize, 1>{}([&](auto i){ + mData[i] = x.At(i); + }) + } +#endif + + __host__ __device__ static constexpr index_t Size() { return NSize; } + + __host__ __device__ static constexpr index_t GetSize() { return Size(); } template - __host__ __device__ TData& operator()(Number) - { - return mData[I]; - } - - __host__ __device__ TData& operator()(index_t i) { return mData[i]; } - - template - __host__ __device__ constexpr void Set(Number, TData x) + __host__ __device__ constexpr const TData& At(Number) const { static_assert(I < NSize, "wrong!"); - mData[I] = x; + return mData[I]; } - __host__ __device__ constexpr void Set(index_t I, TData x) { mData[I] = x; } + template + __host__ __device__ constexpr TData& At(Number) + { + static_assert(I < NSize, "wrong!"); + + return mData[I]; + } + + __host__ __device__ constexpr const TData& At(index_t i) const { return mData[i]; } + + __host__ __device__ constexpr TData& At(index_t i) { return mData[i]; } + + template + __host__ __device__ constexpr const TData& operator[](I i) const + { + return At(i); + } + + template + __host__ __device__ constexpr TData& operator()(I i) + { + return At(i); + } + + template + __host__ __device__ constexpr type& operator=(const T& x) + { + static_for<0, Size(), 1>{}([&](auto i) { operator()(i) = x[i]; }); + + return *this; + } struct lambda_PushBack // emulate constexpr lambda { @@ -63,7 +93,7 @@ struct Array template __host__ __device__ constexpr void operator()(Number) const { - new_array.Set(Number{}, old_array[I]); + new_array(Number{}) = old_array[I]; } }; @@ -73,71 +103,98 @@ struct Array static_for<0, NSize, 1>{}(lambda_PushBack(*this, new_array)); - new_array.Set(Number{}, x); + new_array(Number{}) = x; return new_array; } }; -// A: Array +// Arr: Array // Picks: Sequence<...> -template +template struct ArrayElementPicker { + using type = ArrayElementPicker; using data_type = typename Arr::data_type; - __host__ __device__ constexpr ArrayElementPicker(Arr& array) : mData{array} + __host__ __device__ constexpr ArrayElementPicker() = delete; + + __host__ __device__ explicit constexpr ArrayElementPicker(Arr& array) : mArray{array} { constexpr index_t imax = accumulate_on_sequence(Picks{}, math::maxer{}, Number<0>{}); - static_assert(imax < Picks::GetSize(), "wrong! exceeding max id"); + static_assert(imax < Arr::Size(), "wrong! exceeding # array element"); } - __host__ __device__ static constexpr index_t GetSize() { return Picks::GetSize(); } + __host__ __device__ static constexpr auto Size() { return Picks::Size(); } template - __host__ __device__ constexpr data_type operator[](Number) const + __host__ __device__ constexpr const data_type& At(Number) const { - constexpr auto IP = Picks::Get(Number{}); - return mData[IP]; - } + static_assert(I < Size(), "wrong!"); - __host__ __device__ constexpr data_type operator[](index_t i) const - { - constexpr index_t ip = Picks{}[i]; - return mData[ip]; + constexpr auto IP = Picks{}[I]; + return mArray[IP]; } template - __host__ __device__ data_type& operator()(Number) + __host__ __device__ constexpr data_type& At(Number) { - constexpr auto IP = Picks::Get(Number{}); - return mData[IP]; + static_assert(I < Size(), "wrong!"); + + constexpr auto IP = Picks{}[I]; + return mArray(IP); } - __host__ __device__ data_type& operator()(index_t i) + template + __host__ __device__ constexpr const data_type& operator[](I i) const { - constexpr index_t ip = Picks{}[i]; - return mData[ip]; + return At(i); } - Arr& mData; + template + __host__ __device__ constexpr data_type& operator()(I i) + { + return At(i); + } + + template + __host__ __device__ constexpr type& operator=(const T& a) + { + static_for<0, Size(), 1>{}([&](auto i) { operator()(i) = a[i]; }); + + return *this; + } + + Arr& mArray; }; -template +template __host__ __device__ constexpr auto pick_array_element(Arr& a, Picks) { return ArrayElementPicker(a); } +#if 1 +template +__host__ __device__ constexpr auto to_array(const T& x) +{ + Array y; + + static_for<0, T::Size(), 1>{}([&](auto i) { y.At(i) = x.At(i); }); + + return y; +} +#endif + template __host__ __device__ constexpr auto sequence2array(Sequence) { return Array{Is...}; } -template +template __host__ __device__ constexpr auto make_zero_array() { constexpr auto zero_sequence = typename uniform_sequence_gen::type{}; @@ -145,7 +202,7 @@ __host__ __device__ constexpr auto make_zero_array() return zero_array; } -template +template __host__ __device__ constexpr auto reorder_array_given_new2old(const Array& old_array, Sequence /*new2old*/) { @@ -156,7 +213,7 @@ __host__ __device__ constexpr auto reorder_array_given_new2old(const Array{old_array[IRs]...}; } -template +template struct lambda_reorder_array_given_old2new { const Array& old_array; @@ -173,13 +230,13 @@ struct lambda_reorder_array_given_old2new { TData old_data = old_array[IOldDim]; - constexpr index_t INewDim = MapOld2New::Get(Number{}); + constexpr index_t INewDim = MapOld2New::At(Number{}); - new_array.Set(Number{}, old_data); + new_array(Number{}) = old_data; } }; -template +template __host__ __device__ constexpr auto reorder_array_given_old2new(const Array& old_array, Sequence /*old2new*/) { @@ -195,7 +252,7 @@ __host__ __device__ constexpr auto reorder_array_given_old2new(const Array +template __host__ __device__ constexpr auto extract_array(const Array& old_array, ExtractSeq) { Array new_array; @@ -204,12 +261,13 @@ __host__ __device__ constexpr auto extract_array(const Array& old_ static_assert(new_size <= NSize, "wrong! too many extract"); - static_for<0, new_size, 1>{}([&](auto I) { new_array(I) = old_array[ExtractSeq::Get(I)]; }); + static_for<0, new_size, 1>{}([&](auto I) { new_array(I) = old_array[ExtractSeq::At(I)]; }); return new_array; } -template // emulate constepxr lambda for array math +template // emulate constepxr lambda for array +// math struct lambda_array_math { const F& f; @@ -226,13 +284,12 @@ struct lambda_array_math __host__ __device__ constexpr void operator()(Number) const { constexpr auto IDim = Number{}; - - z.Set(IDim, f(x[IDim], y[IDim])); + z(IDim) = f(x[IDim], y[IDim]); } }; // Array = Array + Array -template +template __host__ __device__ constexpr auto operator+(Array a, Array b) { Array result; @@ -247,7 +304,7 @@ __host__ __device__ constexpr auto operator+(Array a, Array +template __host__ __device__ constexpr auto operator-(Array a, Array b) { Array result; @@ -262,7 +319,7 @@ __host__ __device__ constexpr auto operator-(Array a, Array +template __host__ __device__ constexpr auto operator+=(Array& a, Array b) { a = a + b; @@ -270,14 +327,14 @@ __host__ __device__ constexpr auto operator+=(Array& a, Array +template __host__ __device__ constexpr auto operator-=(Array& a, Array b) { a = a - b; return a; } // Array = Array + Sequence -template +template __host__ __device__ constexpr auto operator+(Array a, Sequence b) { static_assert(sizeof...(Is) == NSize, "wrong! size not the same"); @@ -294,7 +351,7 @@ __host__ __device__ constexpr auto operator+(Array a, Sequence +template __host__ __device__ constexpr auto operator-(Array a, Sequence b) { static_assert(sizeof...(Is) == NSize, "wrong! size not the same"); @@ -311,7 +368,7 @@ __host__ __device__ constexpr auto operator-(Array a, Sequence +template __host__ __device__ constexpr auto operator*(Array a, Sequence b) { static_assert(sizeof...(Is) == NSize, "wrong! size not the same"); @@ -328,7 +385,7 @@ __host__ __device__ constexpr auto operator*(Array a, Sequence +template __host__ __device__ constexpr auto operator-(Sequence a, Array b) { static_assert(sizeof...(Is) == NSize, "wrong! size not the same"); @@ -344,7 +401,7 @@ __host__ __device__ constexpr auto operator-(Sequence a, Array +template __host__ __device__ constexpr TData accumulate_on_array(const Array& a, Reduce f, TData init) { @@ -357,89 +414,5 @@ accumulate_on_array(const Array& a, Reduce f, TData init) return result; } -template -__host__ __device__ void print_Array(const char* s, Array a) -{ - constexpr index_t nsize = a.GetSize(); - - static_assert(nsize > 0 && nsize <= 10, "wrong!"); - - static_if{}([&](auto) { printf("%s size %u, {%u}\n", s, nsize, a[0]); }); - - static_if{}([&](auto) { printf("%s size %u, {%u %u}\n", s, nsize, a[0], a[1]); }); - - static_if{}( - [&](auto) { printf("%s size %u, {%u %u %u}\n", s, nsize, a[0], a[1], a[2]); }); - - static_if{}( - [&](auto) { printf("%s size %u, {%u %u %u %u}\n", s, nsize, a[0], a[1], a[2], a[3]); }); - - static_if{}([&](auto) { - printf("%s size %u, {%u %u %u %u %u}\n", s, nsize, a[0], a[1], a[2], a[3], a[4]); - }); - - static_if{}([&](auto) { - printf("%s size %u, {%u %u %u %u %u %u}\n", s, nsize, a[0], a[1], a[2], a[3], a[4], a[5]); - }); - - static_if{}([&](auto) { - printf("%s size %u, {%u %u %u %u %u %u %u}\n", - s, - nsize, - a[0], - a[1], - a[2], - a[3], - a[4], - a[5], - a[6]); - }); - - static_if{}([&](auto) { - printf("%s size %u, {%u %u %u %u %u %u %u %u}\n", - s, - nsize, - a[0], - a[1], - a[2], - a[3], - a[4], - a[5], - a[6], - a[7]); - }); - - static_if{}([&](auto) { - printf("%s size %u, {%u %u %u %u %u %u %u %u %u}\n", - s, - nsize, - a[0], - a[1], - a[2], - a[3], - a[4], - a[5], - a[6], - a[7], - a[8]); - }); - - static_if{}([&](auto) { - printf("%s size %u, {%u %u %u %u %u %u %u %u %u %u}\n", - s, - nsize, - a[0], - a[1], - a[2], - a[3], - a[4], - a[5], - a[6], - a[7], - a[8], - a[9]); - }); -} - } // namespace ck #endif diff --git a/composable_kernel/include/utility/Sequence.hpp b/composable_kernel/include/utility/Sequence.hpp index 3abdceda16..b9327bdd81 100644 --- a/composable_kernel/include/utility/Sequence.hpp +++ b/composable_kernel/include/utility/Sequence.hpp @@ -12,22 +12,22 @@ struct static_for; template struct Sequence; -template +template struct sequence_split; -template +template struct sequence_reverse; -template +template struct sequence_map_inverse; -template +template struct is_valid_sequence_map; template __host__ __device__ constexpr auto sequence_pop_front(Sequence); -template +template __host__ __device__ constexpr auto sequence_pop_back(Seq); template @@ -38,9 +38,11 @@ struct Sequence static constexpr index_t mSize = sizeof...(Is); - __host__ __device__ static constexpr auto GetSize() { return Number{}; } + __host__ __device__ static constexpr auto Size() { return Number{}; } - __host__ __device__ static constexpr index_t GetImpl(index_t I) + __host__ __device__ static constexpr auto GetSize() { return Size(); } + + __host__ __device__ static constexpr index_t At(index_t I) { // the last dummy element is to prevent compiler complain about empty array, when mSize = 0 const index_t mData[mSize + 1] = {Is..., 0}; @@ -48,23 +50,24 @@ struct Sequence } template - __host__ __device__ static constexpr auto Get(Number) + __host__ __device__ static constexpr auto At(Number) { static_assert(I < mSize, "wrong! I too large"); - return Number{})>{}; + return Number{}; } - __host__ __device__ static constexpr auto Get(index_t I) { return GetImpl(I); } - template - __host__ __device__ constexpr auto operator[](Number) const + __host__ __device__ static constexpr auto Get(Number) { - return Get(Number{}); + return At(Number{}); } - // make sure I is constepxr if you want a constexpr return type - __host__ __device__ constexpr index_t operator[](index_t I) const { return GetImpl(I); } + template + __host__ __device__ constexpr auto operator[](I i) const + { + return At(i); + } template __host__ __device__ static constexpr auto ReorderGivenNew2Old(Sequence /*new2old*/) @@ -74,14 +77,14 @@ struct Sequence static_assert(is_valid_sequence_map>::value, "wrong! invalid reorder map"); - return Sequence{})...>{}; + return Sequence{})...>{}; } // MapOld2New is Sequence<...> - template + template __host__ __device__ static constexpr auto ReorderGivenOld2New(MapOld2New) { - static_assert(MapOld2New::GetSize() == GetSize(), + static_assert(MapOld2New::Size() == Size(), "wrong! reorder map should have the same size as Sequence to be rerodered"); static_assert(is_valid_sequence_map::value, "wrong! invalid reorder map"); @@ -97,13 +100,13 @@ struct Sequence __host__ __device__ static constexpr auto Front() { static_assert(mSize > 0, "wrong!"); - return Get(Number<0>{}); + return At(Number<0>{}); } __host__ __device__ static constexpr auto Back() { static_assert(mSize > 0, "wrong!"); - return Get(Number{}); + return At(Number{}); } __host__ __device__ static constexpr auto PopFront() { return sequence_pop_front(Type{}); } @@ -137,19 +140,19 @@ struct Sequence template __host__ __device__ static constexpr auto Extract(Number...) { - return Sequence{})...>{}; + return Sequence{})...>{}; } template __host__ __device__ static constexpr auto Extract(Sequence) { - return Sequence{})...>{}; + return Sequence{})...>{}; } template __host__ __device__ static constexpr auto Modify(Number, Number) { - static_assert(I < GetSize(), "wrong!"); + static_assert(I < Size(), "wrong!"); using seq_split = sequence_split; constexpr auto seq_left = typename seq_split::SeqType0{}; @@ -158,7 +161,7 @@ struct Sequence return seq_left.PushBack(Number{}).PushBack(seq_right); } - template + template __host__ __device__ static constexpr auto Transform(F f) { return Sequence{}; @@ -166,8 +169,11 @@ struct Sequence }; // merge sequence -template -struct sequence_merge; +template +struct sequence_merge +{ + using type = typename sequence_merge::type>::type; +}; template struct sequence_merge, Sequence> @@ -175,8 +181,14 @@ struct sequence_merge, Sequence> using type = Sequence; }; +template +struct sequence_merge +{ + using type = Seq; +}; + // generate sequence -template +template struct sequence_gen_impl { static constexpr index_t NRemainLeft = NRemain / 2; @@ -188,20 +200,20 @@ struct sequence_gen_impl typename sequence_gen_impl::type>::type; }; -template +template struct sequence_gen_impl { static constexpr index_t Is = F{}(Number{}); using type = Sequence; }; -template +template struct sequence_gen_impl { using type = Sequence<>; }; -template +template struct sequence_gen { using type = typename sequence_gen_impl<0, NSize, F>::type; @@ -235,10 +247,10 @@ struct uniform_sequence_gen }; // reverse inclusive scan (with init) sequence -template +template struct sequence_reverse_inclusive_scan; -template +template struct sequence_reverse_inclusive_scan, Reduce, Init> { using old_scan = typename sequence_reverse_inclusive_scan, Reduce, Init>::type; @@ -248,23 +260,23 @@ struct sequence_reverse_inclusive_scan, Reduce, Init> using type = typename sequence_merge, old_scan>::type; }; -template +template struct sequence_reverse_inclusive_scan, Reduce, Init> { using type = Sequence; }; -template +template struct sequence_reverse_inclusive_scan, Reduce, Init> { using type = Sequence<>; }; // split sequence -template +template struct sequence_split { - static constexpr index_t NSize = Seq{}.GetSize(); + static constexpr index_t NSize = Seq{}.Size(); using range0 = typename arithmetic_sequence_gen<0, I, 1>::type; using range1 = typename arithmetic_sequence_gen::type; @@ -274,10 +286,10 @@ struct sequence_split }; // reverse sequence -template +template struct sequence_reverse { - static constexpr index_t NSize = Seq{}.GetSize(); + static constexpr index_t NSize = Seq{}.Size(); using seq_split = sequence_split; using type = typename sequence_merge< @@ -297,19 +309,102 @@ struct sequence_reverse> using type = Sequence; }; -template +template struct sequence_sort { - // not implemented + template + struct sorted_sequence_merge_impl + { + static constexpr bool pick_left = SeqLeft::Front() < SeqRight::Front(); + static constexpr index_t next_value = pick_left ? SeqLeft::Front() : SeqRight::Front(); + + using new_merged_seq = decltype(MergedSeq::PushBack(Number{})); + + using new_left_seq = + typename conditional::type; + using new_right_seq = + typename conditional::type; + + using type = + typename sorted_sequence_merge_impl:: + type; + }; + + template + struct sorted_sequence_merge_impl, MergedSeq, Comp> + { + using type = typename sequence_merge::type; + }; + + template + struct sorted_sequence_merge_impl, SeqRight, MergedSeq, Comp> + { + using type = typename sequence_merge::type; + }; + + template + struct sorted_sequence_merge + { + using type = typename sorted_sequence_merge_impl, Comp>::type; + }; + + using split = sequence_split; + using unsorted_left = typename split::SeqType0; + using unsorted_right = typename split::SeqType1; + + using sorted_left = typename sequence_sort::type; + using sorted_right = typename sequence_sort::type; + + using type = typename sorted_sequence_merge::type; }; -template +template +struct sequence_sort, Compare> +{ + static constexpr bool x_first = Compare{}(X, Y); + + using type = typename conditional, Sequence>::type; +}; + +template +struct sequence_sort, Compare> +{ + using type = Sequence; +}; + +template struct sequence_unique_sort { - // not implemented + template + struct sorted_sequence_uniquify_impl + { + static constexpr index_t new_value = WorkInputSeq::Front(); + using new_work_input_seq = decltype(WorkInputSeq::PopFront()); + + using new_working_output_seq = + typename conditional{}))>::type; + }; + + template + struct sorted_sequence_uniquify_impl, Eq> + { + using type = WorkInputSeq; + }; + + template + struct sorted_sequence_uniquify + { + using type = typename sorted_sequence_uniquify_impl, Eq>::type; + }; + + using sorted_seq = typename sequence_sort::type; + + using type = typename sorted_sequence_uniquify::type; }; -template +template struct is_valid_sequence_map { // not implemented yet, always return true @@ -317,36 +412,35 @@ struct is_valid_sequence_map // TODO: add proper check for is_valid, something like: // static constexpr bool value = - // is_same::type, + // is_same::type, // typename sequence_sort::SortedSeqType>{}; }; -template +template struct sequence_map_inverse_impl { private: - static constexpr auto new_y2x = - WorkingY2X::Modify(X2Y::Get(Number{}), Number{}); + static constexpr auto new_y2x = WorkingY2X::Modify(X2Y::At(Number{}), Number{}); public: using type = typename sequence_map_inverse_impl::type; }; -template +template struct sequence_map_inverse_impl { using type = WorkingY2X; }; -template +template struct sequence_map_inverse { using type = typename sequence_map_inverse_impl::type, + typename uniform_sequence_gen::type, 0, - X2Y::GetSize()>::type; + X2Y::Size()>::type; }; template @@ -457,20 +551,26 @@ __host__ __device__ constexpr auto sequence_pop_front(Sequence) return Sequence{}; } -template +template __host__ __device__ constexpr auto sequence_pop_back(Seq) { - static_assert(Seq::GetSize() > 0, "wrong! cannot pop an empty Sequence!"); + static_assert(Seq::Size() > 0, "wrong! cannot pop an empty Sequence!"); return sequence_pop_front(Seq::Reverse()).Reverse(); } -template +template __host__ __device__ constexpr auto transform_sequences(F f, Sequence) { return Sequence{}; } -template +template +__host__ __device__ constexpr auto merge_sequences(Seqs...) +{ + return typename sequence_merge::type{}; +} + +template __host__ __device__ constexpr auto transform_sequences(F f, Sequence, Sequence) { static_assert(Sequence::mSize == Sequence::mSize, "Dim not the same"); @@ -478,7 +578,7 @@ __host__ __device__ constexpr auto transform_sequences(F f, Sequence, Seq return Sequence{}; } -template +template __host__ __device__ constexpr auto transform_sequences(F f, Sequence, Sequence, Sequence) { @@ -489,19 +589,19 @@ transform_sequences(F f, Sequence, Sequence, Sequence) return Sequence{}; } -template +template __host__ __device__ constexpr auto reverse_inclusive_scan_sequence(Seq, Reduce, Number) { return typename sequence_reverse_inclusive_scan::type{}; } -template +template __host__ __device__ constexpr auto inclusive_scan_sequence(Seq, Reduce, Number) { return reverse_inclusive_scan_sequence(Seq{}.Reverse(), Reduce{}, Number{}).Reverse(); } -template +template struct lambda_accumulate_on_sequence { const Reduce& f; @@ -512,14 +612,14 @@ struct lambda_accumulate_on_sequence { } - template + template __host__ __device__ constexpr index_t operator()(IDim) const { - return result = f(result, Seq::Get(IDim{})); + return result = f(result, Seq::At(IDim{})); } }; -template +template __host__ __device__ constexpr index_t accumulate_on_sequence(Seq, Reduce f, Number /*initial_value*/) { @@ -530,41 +630,5 @@ accumulate_on_sequence(Seq, Reduce f, Number /*initial_value*/) return result; } -template -__host__ __device__ void print_Sequence(const char* s, Sequence) -{ - constexpr index_t nsize = Sequence::GetSize(); - - static_assert(nsize <= 10, "wrong!"); - - static_if{}([&](auto) { printf("%s size %u, {}\n", s, nsize, Xs...); }); - - static_if{}([&](auto) { printf("%s size %u, {%u}\n", s, nsize, Xs...); }); - - static_if{}([&](auto) { printf("%s size %u, {%u %u}\n", s, nsize, Xs...); }); - - static_if{}([&](auto) { printf("%s size %u, {%u %u %u}\n", s, nsize, Xs...); }); - - static_if{}([&](auto) { printf("%s size %u, {%u %u %u %u}\n", s, nsize, Xs...); }); - - static_if{}( - [&](auto) { printf("%s size %u, {%u %u %u %u %u}\n", s, nsize, Xs...); }); - - static_if{}( - [&](auto) { printf("%s size %u, {%u %u %u %u %u %u}\n", s, nsize, Xs...); }); - - static_if{}( - [&](auto) { printf("%s size %u, {%u %u %u %u %u %u %u}\n", s, nsize, Xs...); }); - - static_if{}( - [&](auto) { printf("%s size %u, {%u %u %u %u %u %u %u %u}\n", s, nsize, Xs...); }); - - static_if{}( - [&](auto) { printf("%s size %u, {%u %u %u %u %u %u %u %u %u}\n", s, nsize, Xs...); }); - - static_if{}( - [&](auto) { printf("%s size %u, {%u %u %u %u %u %u %u %u %u %u}\n", s, nsize, Xs...); }); -} - } // namespace ck #endif diff --git a/composable_kernel/include/utility/array_helper.hpp b/composable_kernel/include/utility/array_helper.hpp new file mode 100644 index 0000000000..55f022be7d --- /dev/null +++ b/composable_kernel/include/utility/array_helper.hpp @@ -0,0 +1,93 @@ +#ifndef CK_ARRAY_HELPER_HPP +#define CK_ARRAY_HELPER_HPP + +#include "Array.hpp" + +namespace ck { + +template +__host__ __device__ void print_Array(const char* s, Array a) +{ + constexpr index_t nsize = a.GetSize(); + + static_assert(nsize > 0 && nsize <= 10, "wrong!"); + + static_if{}([&](auto) { printf("%s size %u, {%u}\n", s, nsize, a[0]); }); + + static_if{}([&](auto) { printf("%s size %u, {%u %u}\n", s, nsize, a[0], a[1]); }); + + static_if{}( + [&](auto) { printf("%s size %u, {%u %u %u}\n", s, nsize, a[0], a[1], a[2]); }); + + static_if{}( + [&](auto) { printf("%s size %u, {%u %u %u %u}\n", s, nsize, a[0], a[1], a[2], a[3]); }); + + static_if{}([&](auto) { + printf("%s size %u, {%u %u %u %u %u}\n", s, nsize, a[0], a[1], a[2], a[3], a[4]); + }); + + static_if{}([&](auto) { + printf("%s size %u, {%u %u %u %u %u %u}\n", s, nsize, a[0], a[1], a[2], a[3], a[4], a[5]); + }); + + static_if{}([&](auto) { + printf("%s size %u, {%u %u %u %u %u %u %u}\n", + s, + nsize, + a[0], + a[1], + a[2], + a[3], + a[4], + a[5], + a[6]); + }); + + static_if{}([&](auto) { + printf("%s size %u, {%u %u %u %u %u %u %u %u}\n", + s, + nsize, + a[0], + a[1], + a[2], + a[3], + a[4], + a[5], + a[6], + a[7]); + }); + + static_if{}([&](auto) { + printf("%s size %u, {%u %u %u %u %u %u %u %u %u}\n", + s, + nsize, + a[0], + a[1], + a[2], + a[3], + a[4], + a[5], + a[6], + a[7], + a[8]); + }); + + static_if{}([&](auto) { + printf("%s size %u, {%u %u %u %u %u %u %u %u %u %u}\n", + s, + nsize, + a[0], + a[1], + a[2], + a[3], + a[4], + a[5], + a[6], + a[7], + a[8], + a[9]); + }); +} + +} // namespace ck +#endif \ No newline at end of file diff --git a/composable_kernel/include/utility/common_header.hpp b/composable_kernel/include/utility/common_header.hpp index a1ec782c9a..902e78f25c 100644 --- a/composable_kernel/include/utility/common_header.hpp +++ b/composable_kernel/include/utility/common_header.hpp @@ -4,14 +4,19 @@ #include "config.hpp" #include "utility.hpp" #include "integral_constant.hpp" +#include "number.hpp" +#include "type.hpp" #include "tuple.hpp" #include "math.hpp" #include "vector_type.hpp" #include "Sequence.hpp" +#include "sequence_helper.hpp" #include "Array.hpp" +#include "array_helper.hpp" #include "functional.hpp" #include "functional2.hpp" #include "functional3.hpp" +#include "functional4.hpp" #if CK_USE_AMD_INLINE_ASM #include "amd_inline_asm.hpp" diff --git a/composable_kernel/include/utility/functional.hpp b/composable_kernel/include/utility/functional.hpp index e1f6b004ce..a22f8b2435 100644 --- a/composable_kernel/include/utility/functional.hpp +++ b/composable_kernel/include/utility/functional.hpp @@ -3,9 +3,11 @@ #include "integral_constant.hpp" #include "Sequence.hpp" +#include "type.hpp" namespace ck { +// TODO: right? wrong? struct forwarder { template @@ -17,7 +19,7 @@ struct forwarder struct swallow { - template + template __host__ __device__ constexpr swallow(Ts&&...) { } @@ -32,7 +34,7 @@ struct static_if { using Type = static_if; - template + template __host__ __device__ constexpr auto operator()(F f) const { // This is a trick for compiler: @@ -43,7 +45,7 @@ struct static_if return Type{}; } - template + template __host__ __device__ static constexpr auto Else(F) { return Type{}; @@ -55,13 +57,13 @@ struct static_if { using Type = static_if; - template + template __host__ __device__ constexpr auto operator()(F) const { return Type{}; } - template + template __host__ __device__ static constexpr auto Else(F f) { // This is a trick for compiler: @@ -73,5 +75,23 @@ struct static_if } }; +template +struct conditional; + +template +struct conditional +{ + using type = X; +}; + +template +struct conditional +{ + using type = Y; +}; + +template +using conditional_t = typename conditional::type; + } // namespace ck #endif diff --git a/composable_kernel/include/utility/functional2.hpp b/composable_kernel/include/utility/functional2.hpp index 52e96b90f5..df10ca1f25 100644 --- a/composable_kernel/include/utility/functional2.hpp +++ b/composable_kernel/include/utility/functional2.hpp @@ -6,6 +6,8 @@ namespace ck { +namespace detail { + template struct static_for_impl; @@ -19,6 +21,8 @@ struct static_for_impl> } }; +} // namespace detail + // F signature: F(Number) template struct static_for @@ -33,7 +37,8 @@ struct static_for template __host__ __device__ constexpr void operator()(F f) const { - static_for_impl::type>{}(f); + detail::static_for_impl::type>{}( + f); } }; diff --git a/composable_kernel/include/utility/functional3.hpp b/composable_kernel/include/utility/functional3.hpp index f1c21d7f59..540938fd10 100644 --- a/composable_kernel/include/utility/functional3.hpp +++ b/composable_kernel/include/utility/functional3.hpp @@ -8,20 +8,7 @@ namespace ck { -template -struct is_static : integral_constant -{ -}; - -template -struct is_static> : integral_constant -{ -}; - -template -struct is_static> : integral_constant -{ -}; +namespace detail { // RemainLengths: Sequence<...> // Orders: Sequence<...> @@ -58,29 +45,6 @@ struct static_ford_impl, Orders> } }; -// Lengths is Sequence<...>, it is the length of each dimension for N-dimensional loop -// Orders is Sequence<...>, it is the order of dimension in which static_ford will loop over each -// dimension -template ::type> -struct static_ford -{ - __host__ __device__ constexpr static_ford() - { - static_assert(Lengths::GetSize() > 0, "wrong! Lengths is empty"); - static_assert(Lengths::GetSize() == Orders::GetSize(), "wrong! inconsistent size"); - } - - // F signature: F(Sequence<...> multi_id) - // multi_id is the unordered multi-index - template - __host__ __device__ constexpr void operator()(F f) const - { - constexpr auto ordered_lengths = Lengths::ReorderGivenNew2Old(Orders{}); - static_ford_impl{}(f, Sequence<>{}); - } -}; - // RemainLengths: Sequence<...> // Orders: Sequence<...> template @@ -117,6 +81,31 @@ struct ford_impl, Orders> } }; +} // namespace detail + +// Lengths is Sequence<...>, it is the length of each dimension for N-dimensional loop +// Orders is Sequence<...>, it is the order of dimension in which static_ford will loop over each +// dimension +template ::type> +struct static_ford +{ + __host__ __device__ constexpr static_ford() + { + static_assert(Lengths::GetSize() > 0, "wrong! Lengths is empty"); + static_assert(Lengths::GetSize() == Orders::GetSize(), "wrong! inconsistent size"); + } + + // F signature: F(Sequence<...> multi_id) + // multi_id is the unordered multi-index + template + __host__ __device__ constexpr void operator()(F f) const + { + constexpr auto ordered_lengths = Lengths::ReorderGivenNew2Old(Orders{}); + detail::static_ford_impl{}(f, Sequence<>{}); + } +}; + // Lengths is Sequence<...>, it is the length of each dimension for N-dimensional loop // Orders is Sequence<...>, it is the order of dimension in which ford will loop over each // dimension @@ -139,7 +128,8 @@ struct ford for(index_t i = 0; i < ordered_lengths.Front(); ++i) { - ford_impl{}(f, Array{i}); + detail::ford_impl{}(f, + Array{i}); } } }; diff --git a/composable_kernel/include/utility/functional4.hpp b/composable_kernel/include/utility/functional4.hpp new file mode 100644 index 0000000000..2cbc94ea8b --- /dev/null +++ b/composable_kernel/include/utility/functional4.hpp @@ -0,0 +1,34 @@ +#ifndef CK_FUNCTIONAL4_HPP +#define CK_FUNCTIONAL4_HPP + +#include "Sequence.hpp" +#include "tuple.hpp" +#include "Array.hpp" + +namespace ck { + +namespace detail { + +template +struct unpack_impl; + +template +struct unpack_impl> +{ + template + __host__ __device__ constexpr auto operator()(F f, const X& x) const + { + return f(x.At(Number{})...); + } +}; + +} // namespace detail + +template +__host__ __device__ constexpr auto unpack(F f, const X& x) +{ + return detail::unpack_impl::type>{}(f, x); +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/integral_constant.hpp b/composable_kernel/include/utility/integral_constant.hpp index cae52ebe3a..14f3df894b 100644 --- a/composable_kernel/include/utility/integral_constant.hpp +++ b/composable_kernel/include/utility/integral_constant.hpp @@ -13,54 +13,5 @@ struct integral_constant __host__ __device__ constexpr value_type operator()() const noexcept { return value; } }; -template -struct is_same : public integral_constant -{ -}; - -template -struct is_same : public integral_constant -{ -}; - -template -using remove_cv_t = typename std::remove_cv::type; - -template -using Number = integral_constant; - -template -__host__ __device__ constexpr auto operator+(Number, Number) -{ - return Number{}; -} - -template -__host__ __device__ constexpr auto operator-(Number, Number) -{ - static_assert(Y <= X, "wrong!"); - return Number{}; -} - -template -__host__ __device__ constexpr auto operator*(Number, Number) -{ - return Number{}; -} - -template -__host__ __device__ constexpr auto operator/(Number, Number) -{ - static_assert(Y > 0, "wrong!"); - return Number{}; -} - -template -__host__ __device__ constexpr auto operator%(Number, Number) -{ - static_assert(Y > 0, "wrong!"); - return Number{}; -} - } // namespace ck #endif diff --git a/composable_kernel/include/utility/math.hpp b/composable_kernel/include/utility/math.hpp index 9e987df11f..7d7252cd4d 100644 --- a/composable_kernel/include/utility/math.hpp +++ b/composable_kernel/include/utility/math.hpp @@ -104,6 +104,18 @@ __host__ __device__ constexpr T lcm(T x, Ts... xs) return max(x, xs...); } +template +struct equal +{ + __host__ __device__ constexpr bool operator()(T x, T y) const { return x == y; } +}; + +template +struct less +{ + __host__ __device__ constexpr bool operator()(T x, T y) const { return x < y; } +}; + } // namespace math } // namspace ck diff --git a/composable_kernel/include/utility/number.hpp b/composable_kernel/include/utility/number.hpp new file mode 100644 index 0000000000..f8c5643694 --- /dev/null +++ b/composable_kernel/include/utility/number.hpp @@ -0,0 +1,44 @@ +#ifndef CK_NUMBER_HPP +#define CK_NUMBER_HPP + +#include "integral_constant.hpp" + +namespace ck { + +template +using Number = integral_constant; + +template +__host__ __device__ constexpr auto operator+(Number, Number) +{ + return Number{}; +} + +template +__host__ __device__ constexpr auto operator-(Number, Number) +{ + static_assert(Y <= X, "wrong!"); + return Number{}; +} + +template +__host__ __device__ constexpr auto operator*(Number, Number) +{ + return Number{}; +} + +template +__host__ __device__ constexpr auto operator/(Number, Number) +{ + static_assert(Y > 0, "wrong!"); + return Number{}; +} + +template +__host__ __device__ constexpr auto operator%(Number, Number) +{ + static_assert(Y > 0, "wrong!"); + return Number{}; +} +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/sequence_helper.hpp b/composable_kernel/include/utility/sequence_helper.hpp new file mode 100644 index 0000000000..c499d6e9a2 --- /dev/null +++ b/composable_kernel/include/utility/sequence_helper.hpp @@ -0,0 +1,46 @@ +#ifndef CK_SEQUENCE_HELPER_HPP +#define CK_SEQUENCE_HELPER_HPP + +#include "Sequence.hpp" + +namespace ck { + +template +__host__ __device__ void print_Sequence(const char* s, Sequence) +{ + constexpr index_t nsize = Sequence::Size(); + + static_assert(nsize <= 10, "wrong!"); + + static_if{}([&](auto) { printf("%s size %u, {}\n", s, nsize, Xs...); }); + + static_if{}([&](auto) { printf("%s size %u, {%u}\n", s, nsize, Xs...); }); + + static_if{}([&](auto) { printf("%s size %u, {%u %u}\n", s, nsize, Xs...); }); + + static_if{}([&](auto) { printf("%s size %u, {%u %u %u}\n", s, nsize, Xs...); }); + + static_if{}([&](auto) { printf("%s size %u, {%u %u %u %u}\n", s, nsize, Xs...); }); + + static_if{}( + [&](auto) { printf("%s size %u, {%u %u %u %u %u}\n", s, nsize, Xs...); }); + + static_if{}( + [&](auto) { printf("%s size %u, {%u %u %u %u %u %u}\n", s, nsize, Xs...); }); + + static_if{}( + [&](auto) { printf("%s size %u, {%u %u %u %u %u %u %u}\n", s, nsize, Xs...); }); + + static_if{}( + [&](auto) { printf("%s size %u, {%u %u %u %u %u %u %u %u}\n", s, nsize, Xs...); }); + + static_if{}( + [&](auto) { printf("%s size %u, {%u %u %u %u %u %u %u %u %u}\n", s, nsize, Xs...); }); + + static_if{}( + [&](auto) { printf("%s size %u, {%u %u %u %u %u %u %u %u %u %u}\n", s, nsize, Xs...); }); +} + +} // namespace ck + +#endif diff --git a/composable_kernel/include/utility/tuple.hpp b/composable_kernel/include/utility/tuple.hpp index 019430cd92..3fa9d0fccd 100644 --- a/composable_kernel/include/utility/tuple.hpp +++ b/composable_kernel/include/utility/tuple.hpp @@ -2,6 +2,7 @@ #define CK_TUPLE_HPP #include "integral_constant.hpp" +#include "type.hpp" #include "Sequence.hpp" namespace ck { @@ -16,6 +17,8 @@ struct TupleElementKey template struct TupleElement { + __host__ __device__ explicit constexpr TupleElement() : mData() {} + template __host__ __device__ explicit constexpr TupleElement(T&& v) : mData(static_cast(v)) { @@ -48,6 +51,12 @@ struct TupleImpl; template struct TupleImpl, Xs...> : TupleElement, Xs>... { +#if 1 + __host__ __device__ explicit constexpr TupleImpl() : TupleElement, Xs>()... + { + } +#endif + template __host__ __device__ explicit constexpr TupleImpl(Ys&&... ys) : TupleElement, Xs>(static_cast(ys))... @@ -97,5 +106,28 @@ struct Tuple : detail::TupleImpl +__host__ __device__ constexpr auto make_tuple(Xs&&... xs) +{ + return Tuple>...>(std::forward(xs)...); +} + +namespace detail { + +template +__host__ __device__ constexpr auto transpose_tuple_impl(X& x, F f, Sequence) +{ + return make_tuple(f(x.At(Number{}))...); +} + +} // namespace detail + +template +__host__ __device__ constexpr auto transpose_tuple(X& x, F f) +{ + return detail::transpose_tuple_impl( + x, f, typename arithmetic_sequence_gen<0, X::Size(), 1>::type{}); +} + } // namespace ck #endif diff --git a/composable_kernel/include/utility/type.hpp b/composable_kernel/include/utility/type.hpp new file mode 100644 index 0000000000..98a4640027 --- /dev/null +++ b/composable_kernel/include/utility/type.hpp @@ -0,0 +1,41 @@ +#ifndef CK_TYPE_HPP +#define CK_TYPE_HPP + +#include "integral_constant.hpp" +#include "Sequence.hpp" + +namespace ck { + +template +struct is_same : public integral_constant +{ +}; + +template +struct is_same : public integral_constant +{ +}; + +template +struct is_static : integral_constant +{ +}; + +template +struct is_static> : integral_constant +{ +}; + +template +struct is_static> : integral_constant +{ +}; + +template +using remove_reference_t = typename std::remove_reference::type; + +template +using remove_cv_t = typename std::remove_cv::type; + +} // namespace ck +#endif diff --git a/driver/include/device_convolution_implicit_gemm_v1_chwn_cyxk_khwn_padded.hpp b/driver/include/device_convolution_implicit_gemm_v1_chwn_cyxk_khwn_padded.hpp index 9649caea0a..857924db25 100644 --- a/driver/include/device_convolution_implicit_gemm_v1_chwn_cyxk_khwn_padded.hpp +++ b/driver/include/device_convolution_implicit_gemm_v1_chwn_cyxk_khwn_padded.hpp @@ -115,8 +115,12 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn_padded(InDesc, constexpr index_t OutThreadCopyDataPerAccess_N = 4; #endif +#if 0 // debug constexpr index_t GridSize = (N / NPerBlock) * (K / KPerBlock) * (Ho / HoPerBlock) * (Wo / WoPerBlock); +#else + constexpr index_t GridSize = 1; +#endif printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize); diff --git a/driver/src/driver.cpp b/driver/src/driver.cpp index 7304647362..e25426d812 100644 --- a/driver/src/driver.cpp +++ b/driver/src/driver.cpp @@ -73,19 +73,19 @@ int main(int argc, char* argv[]) using namespace ck; #if 1 - constexpr index_t N = 64; - constexpr index_t C = 1536; - constexpr index_t HI = 8; - constexpr index_t WI = 8; - constexpr index_t K = 256; + constexpr index_t N = 10; + constexpr index_t C = 10; + constexpr index_t HI = 10; + constexpr index_t WI = 10; + constexpr index_t K = 10; constexpr index_t Y = 1; constexpr index_t X = 1; using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>; - constexpr index_t HPad = 0; - constexpr index_t WPad = 0; + constexpr index_t HPad = 2; + constexpr index_t WPad = 2; #elif 1 // 3x3, 34x34 constexpr index_t N = 64;