diff --git a/composable_kernel/include/tensor_description/dimension_transform.hpp b/composable_kernel/include/tensor_description/dimension_transform.hpp
deleted file mode 100644
index f8b513ef19..0000000000
--- a/composable_kernel/include/tensor_description/dimension_transform.hpp
+++ /dev/null
@@ -1,217 +0,0 @@
-#ifndef CK_DIMENSION_TRANSFORM_HPP
-#define CK_DIMENSION_TRANSFORM_HPP
-
-#include "common_header.hpp"
-
-namespace ck {
-
-template
-using MultiIndex = Array;
-
-// LowLengths: Sequence<...>
-template
-struct PassThrough
-{
-    static constexpr index_t nDim = LowLengths::GetSize();
-
-    using LowerId = MultiIndex;
-    using UpperId = LowerId;
-
-    __host__ __device__ static constexpr auto GetLowerNumOfDimension() { return Number{}; }
-
-    __host__ __device__ static constexpr auto GetUpperNumOfDimension()
-    {
-        return GetLowerNumOfDimension();
-    }
-
-    __host__ __device__ static constexpr auto GetLowerLengths() { return LowLengths{}; }
-
-    __host__ __device__ static constexpr auto GetUpperLengths() { return GetLowerLengths(); }
-
-    __host__ __device__ static constexpr auto GetLowerId(UpperId id_up) { return id_up; }
-
-    __host__ __device__ static constexpr auto GetLowerIdDiff(UpperId id_up_diff)
-    {
-        return id_up_diff;
-    }
-
-    __host__ __device__ static constexpr bool IsLinearTransform() { return true; }
-};
-
-// LowLengths: Sequence<...>
-template
-struct Pad
-{
-    static constexpr index_t nDim = LowLengths::GetSize();
-
-    using LowerId = MultiIndex;
-    using UpperId = LowerId;
-
-    __host__ __device__ static constexpr auto GetLowerNumOfDimension() { return Number{}; }
-
-    __host__ __device__ static constexpr auto GetUpperNumOfDimension()
-    {
-        return GetLowerNumOfDimension();
-    }
-
-    __host__ __device__ static constexpr auto GetLowerLengths() { return LowLengths{}; }
-
-    __host__ __device__ static constexpr auto GetUpperLengths()
-    {
-        return GetLowerLengths() + LeftPads + RightPads;
-    }
-
-    __host__ __device__ static constexpr auto GetLowerId(UpperId id_up) { return id_up - LeftPads; }
-
-    __host__ __device__ static constexpr auto GetLowerIdDiff(UpperId id_up_diff)
-    {
-        return id_up_diff;
-    }
-
-    __host__ __device__ static constexpr bool IsLinearTransform() { return true; }
-};
-
-// LowLengths: Sequence<...>
-template
-struct Merge
-{
-    static constexpr index_t nDimLow = LowLengths::GetSize();
-    static constexpr index_t nDimUp = 1;
-
-    using LowerId = MultiIndex;
-    using UpperId = MultiIndex;
-
-    __host__ __device__ static constexpr auto GetUpperNumOfDimension(){return Number{}};
-
-    __host__ __device__ static constexpr auto GetLowerNumOfDimension() { return Number{}; }
-
-    __host__ __device__ static constexpr auto GetLowerLengths() { return LowLengths{}; }
-
-    __host__ __device__ static constexpr auto GetUpperLengths()
-    {
-        return Sequence{}, Number<1>{})>{};
-    }
-
-    __host__ __device__ static constexpr auto GetLowerId(UpperId id_up)
-    {
-        LowerId id_low;
-
-        // not implemeneted
-
-        return id_low;
-    }
-
-    // id_low_diff depends on id_low_old, so id_low need to be up-to-date
-    __host__ __device__ static constexpr auto GetLowerIdDiff(UpperId id_up_diff, LowerId id_low_old)
-    {
-        LowerId id_low_diff;
-
-        // not implemeneted
-
-        return id_low_diff;
-    }
-
-    __host__ __device__ static constexpr bool IsLinearTransform() { return false; }
-};
-
-// UpLengths: Sequence<...>
-template
-struct Unmerge
-{
-    static constexpr index_t nDimLow = 1;
-    static constexpr index_t nDimUp = UpLengths::GetSize();
-
-    __host__ __device__ constexpr Unmerge()
-    {
-        static_assert(LowLength == accumulate_on_sequence(
-                          UpLengths{}, math::multiplies{}, Number<1>{}),
-                      "wrong! UpLengths need to be ");
-    }
-
-    __host__ __device__ static constexpr auto GetUpperNumOfDimension(){return Number{}};
-
-    __host__ __device__ static constexpr auto GetLowerNumOfDimension() { return Number{}; }
-
-    __host__ __device__ static constexpr auto GetLowerLengths() { return Sequence{}; }
-
-    __host__ __device__ static constexpr auto GetUpperLengths() { return UpLengths{}; }
-
-    __host__ __device__ static constexpr auto GetLowerId(UpperId id_up)
-    {
-        constexpr auto scans = typename sequence_reverse_inclusive_scan,
-            1>::type{};
-
-        LowerId id_low{0};
-
-        static_for<0, nDim, 1>{}([&](auto idim) { id_low[0] += id_up[idim] * scans[idim]; });
-
-        return id_low;
-    }
-
-    __host__ __device__ static constexpr auto GetLowerIdDiff(UpperId id_up_diff)
-    {
-        return GetLowerId(id_up_diff);
-    }
-
-    __host__ __device__ static constexpr bool IsLinearTransform() { return true; }
-};
-
-// UpLengths: Sequence<...>
-// Coefficients: Sequence<...>
-// id_low = coefficients[0, ...nDimUp-1] * id_up[0, ...nDimUp-1] + coefficients[nDimUp]
-template
-struct Embed
-{
-    static constexpr index_t nDimLow = 1;
-    static constexpr index_t nDimUp = UpLengths::GetSize();
-
-    static constexpr auto mCoefficients = Coefficients{};
-
-    __host__ __device__ constexpr Embed()
-    {
-        static_assert(UpLengths::GetSize() == nDimUp && Coefficients::GetSize() == nDimUp + 1,
-                      "wrong! # of dimensions not consistent");
-
-        constexpr index_t low_id_max =
-            Coefficents.Back() + accumulate_on_sequence(UpLengths{} * Coefficients::PopBack(),
-                                                        math::plus{},
-                                                        Number<0>{});
-
-        static_assert(low_id_max < LowLength, "wrong! lower-id will go out of range");
-    }
-
-    __host__ __device__ static constexpr auto GetUpperNumOfDimension(){return Number{}};
-
-    __host__ __device__ static constexpr auto GetLowerNumOfDimension() { return Number{}; }
-
-    __host__ __device__ static constexpr auto GetLowerLengths() { return Sequence{}; }
-
-    __host__ __device__ static constexpr auto GetUpperLengths() { return UpLengths{}; }
-
-    __host__ __device__ static constexpr auto GetLowerId(UpperId id_up)
-    {
-        LowerId id_low{mCoefficients[nDimUp]};
-
-        static_for<0, nDimUp, 1>{}(
-            [&](auto idim) { id_low[0] += id_up[idim] * mCoefficients[idim]; });
-
-        return id_low;
-    }
-
-    __host__ __device__ static constexpr auto GetLowerIdDiff(UpperId id_up_diff)
-    {
-        LowerId id_low_diff{0};
-
-        static_for<0, nDimUp, 1>{}(
-            [&](auto idim) { id_low_diff[0] += id_up_diff[idim] * mCoefficients[idim]; });
-
-        return id_low_diff;
-    }
-
-    __host__ __device__ static constexpr bool IsLinearTransform() { return true; }
-};
-
-} // namespace ck
-#endif
diff --git a/composable_kernel/include/tensor_description/tensor_descriptor.hpp b/composable_kernel/include/tensor_description/tensor_descriptor.hpp
index 62f23b2bbd..b2467e7e0c 100644
--- a/composable_kernel/include/tensor_description/tensor_descriptor.hpp
+++ b/composable_kernel/include/tensor_description/tensor_descriptor.hpp
@@ -3,75 +3,159 @@
 
 #include "common_header.hpp"
 #include "dimension.hpp"
+#include "multi_index_transform.hpp"
 
 namespace ck {
 
-template
+template
 struct NativeTensorDescriptor
 {
-    using type = NativeTensorDescriptor;
-    static constexpr index_t nDim = Lengths::GetSize();
+    using type = NativeTensorDescriptor;
+    static constexpr auto mDimensions = Tuple;
+    // NOTE(review): mDimensions is a value, so its size must be taken from its type
+    static constexpr index_t nDim = decltype(mDimensions)::GetSize();
 
-    using Id = MultiIndex;
+    using Index = MultiIndex;
 
     __host__ __device__ static constexpr auto GetNumOfDimension() { return Number{}; }
 
-    __host__ __device__ static constexpr auto GetLengths() { return Lengths{}; }
-
-    __host__ __device__ static constexpr auto GetStrides() { return Strides{}; }
-
-    __host__ __device__ static constexpr auto GetLength(index_t IDim) { return Lengths{}[IDim]; }
-
-    __host__ __device__ static constexpr auto GetStride(index_t IDim) { return Strides{}[IDim]; }
-
-    __host__ __device__ static constexpr index_t GetOffset(Id id)
+    __host__ __device__ static constexpr auto GetLengths()
     {
         // not implemented
     }
+
+    __host__ __device__ static constexpr auto GetStrides()
+    {
+        // not implemented
+    }
+
+    template
+    __host__ __device__ static constexpr auto GetLength(Number)
+    {
+        return mDimensions.Get(Number{}).GetLength();
+    }
+
+    template
+    __host__ __device__ static constexpr auto GetStride(Number)
+    {
+        return mDimensions.Get(Number{}).GetStride();
+    }
+
+    // offset = sum_i idx[i] * stride[i]
+    __host__ __device__ static constexpr index_t GetOffset(Index idx)
+    {
+        index_t offset = 0;
+
+        static_for<0, nDim, 1>{}([&](auto idim) { offset += idx[idim] * GetStride(idim); });
+
+        return offset;
+    }
+
+    // offset difference of a (linear) native descriptor depends only on the index difference
+    __host__ __device__ static constexpr index_t GetOffsetDiff(Index idx_diff)
+    {
+        index_t offset_diff = 0;
+
+        static_for<0, nDim, 1>{}(
+            [&](auto idim) { offset_diff += idx_diff[idim] * GetStride(idim); });
+
+        return offset_diff;
+    }
+
+    __host__ __device__ static constexpr auto AreUpperIndex2OffsetTransformLinear()
+    {
+        // TODO: re-implement "Sequence", so that it can take other data-type (including bool) as
+        // element
+        return uniform_sequence_gen{};
+    }
+
+    __host__ __device__ static constexpr auto GetIndependentDimensionGroups()
+    {
+        // not implemented, should return Tuple, Sequence<1>, ...>
+        return xxx;
+    }
 };
 
 // LowerTensorDescriptor
 // Transforms: std::tuple
-// LowerIds: std::tuple>
-// UpperIds: std::tuple>
-template
+// LowerDimensionIds: std::tuple>
+// UpperDimensionIds: std::tuple>
+template
 struct TransformedTensorDescriptor
 {
     using type = TransformedTensorDescriptor;
-    static constexpr index_t nDimUp = xxxx;
-    static constexpr index_t nDimLow = xxx;
+    static constexpr index_t nDimUp  = GetNumOfUpperDimension();
+    static constexpr index_t nDimLow = GetNumOfLowerDimension();
     static constexpr index_t nTransform = Transforms::GetSize();
 
-    using UpperId = MultiIndex;
-    using LowerId = MultiIndex;
+    using UpperIndex = MultiIndex;
+    using LowerIndex = MultiIndex;
 
     __host__ __device__ static constexpr TransformedTensorDescriptor()
     {
         static_assert(nTransform == Transforms::GetSize() &&
-                          nTransform == LowDimensionMasks::GetSize() &&
-                          nTransform == UpDimensionMasks::GetSize(),
+                          nTransform == LowDimensionIds::GetSize() &&
+                          nTransform == UpDimensionIds::GetSize(),
                       "wrong! # of transformations not the same");
 
-        // TODO: sanity check: LowDimensionMasks should include all low-dimensions,
-        // UpDimensionMasks should include all up-dimensions
+        // TODO: sanity check: LowDimensionIds should include all low-dimensions,
+        // UpDimensionIds should include all up-dimensions
 
         // TODO: sanity check: while a up-dimension could be associated with multille
        // transformation,
        // a low-dimension should be associated with only one transformation
     }
 
+    __host__ __device__ static constexpr auto GetNumOfLowerDimension()
+    {
+        // Here, we assume all lower-dimensions are active
+        // TODO: sanity-check all lower-dimension are indeed active
+        constexpr auto low_active_dims = unique_sort_sequence(
+            merge_tuple_of_sequences(LowDimensionIds{}), math::less{});
+
+        return low_active_dims.GetSize();
+    }
+
+    __host__ __device__ static constexpr auto GetNumOfUpperDimension()
+    {
+        constexpr auto up_active_dims =
+            unique_sort_sequence(merge_tuple_of_sequences(UpDimensionIds{}), math::less{});
+
+        return up_active_dims.GetSize();
+    }
+
     __host__ __device__ static constexpr auto GetNumOfDimension()
     {
-        // not implemented
+        return GetNumOfUpperDimension();
     }
 
     __host__ __device__ static constexpr auto GetLengths()
     {
-        // not implemented
+        struct lambda_get_upper_lengths
+        {
+            template
+            __host__ __device__ constexpr auto operator()(Transform tran) const
+            {
+                return tran.GetUpperLengths();
+            }
+        };
+
+        // transform_tuple expects a tuple *value*, not the tuple type
+        constexpr auto tuple_of_upper_lengths =
+            transform_tuple(Transforms{}, lambda_get_upper_lengths{});
+
+        constexpr auto all_upper_lengths = merge_tuple_of_sequences(tuple_of_upper_lengths);
+
+        constexpr auto all_upper_dimension_ids = merge_tuple_of_sequences(UpDimensionIds{});
+
+        // TODO: sanity-check all_upper_dimension_ids contain all upper-dimensions
+        // TODO: sanity-check all_upper_lengths have no conflicting upper-length
+
+        using sort_dimension_ids = sequence_unique_sort>;
+        constexpr auto sorted_upper_dimension_ids = typename sort_dimension_ids::type{};
+        constexpr auto sorted2unsorted_map = typename sort_dimension_ids::sorted2unsorted_map_type{};
+
+        constexpr auto sorted_upper_lengths =
+            sequence_element_pick(all_upper_lengths, sorted2unsorted_map);
+
+        return sorted_upper_lengths;
     }
 
     __host__ __device__ static constexpr auto GetLowerTensorDescriptor()
@@ -79,17 +163,57 @@ struct TransformedTensorDescriptor
         return LowTensorDescriptor{};
     }
 
-    __host__ __device__ static constexpr index_t GetLowerId(UpperId id_up)
+    __host__ __device__ static constexpr LowerIndex GetLowerIndex(UpperIndex idx_up)
     {
-        // not implemented
+        LowerIndex idx_low;
+
+        static_for<0, nTransform, 1>{}([&](auto itran) {
+            constexpr auto tran = Transforms::Get(itran);
+
+            // writable views into the picked sub-indices (must be mutable, not constexpr)
+            auto idx_low_part = pick_array_element(idx_low, LowDimensionIds::Get(itran));
+            auto idx_up_part  = pick_array_element(idx_up, UpDimensionIds::Get(itran));
+
+            // this assumes each lower (single) index is only associated with one transformation,
+            // which is required for index transformation, and has been checked during constructor
+            // of TransformedTensorDescriptor
+            idx_low_part = tran.GetLowerIndex(idx_up_part);
+        });
+
+        return idx_low;
     }
 
-    __host__ __device__ static constexpr index_t GetOffset(UpperId id_up)
+    __host__ __device__ static constexpr LowerIndex GetLowerIndexDiff(UpperIndex idx_up_diff,
+                                                                      LowerIndex idx_low_old)
     {
-        return GetLowerTensorDescriptor().GetOffset(GetLowerId(id_up));
+        LowerIndex idx_low_diff;
+
+        static_for<0, nTransform, 1>{}([&](auto itran) {
+            constexpr auto tran = Transforms::Get(itran);
+
+            auto idx_up_diff_part =
+                pick_array_element(idx_up_diff, UpDimensionIds::Get(itran));
+
+            auto idx_low_diff_part =
+                pick_array_element(idx_low_diff, LowDimensionIds::Get(itran));
+
+            auto idx_low_old_part =
+                pick_array_element(idx_low_old, LowDimensionIds::Get(itran));
+
+            // this assumes each lower (single) index is associated with only one transformation,
+            // which is required for index transformation, and has been checked during constructor
+            // of TransformedTensorDescriptor
+            idx_low_diff_part = tran.GetLowerIndexDiff(idx_up_diff_part, idx_low_old_part);
+        });
+
+        return idx_low_diff;
     }
 
-    __host__ __device__ static constexpr auto AreUpperId2OffsetLinear();
+    __host__ __device__ static constexpr index_t GetOffset(UpperIndex idx_up)
+    {
+        return GetLowerTensorDescriptor().GetOffset(GetLowerIndex(idx_up));
+    }
+
+    __host__ __device__ static constexpr auto AreUpperIndex2OffsetTransformLinear()
     {
         // not implemented
     }
diff --git a/composable_kernel/include/utility/Array.hpp b/composable_kernel/include/utility/Array.hpp
index 3b7bba8429..036fbc7d9c 100644
--- a/composable_kernel/include/utility/Array.hpp
+++ b/composable_kernel/include/utility/Array.hpp
@@ -79,6 +79,56 @@ struct Array
     }
 };
 
+// A non-owning view over selected elements of an Array
+// A: Array
+// Picks: Sequence<...>
+template
+struct ArrayElementPicker
+{
+    __host__ __device__ constexpr ArrayElementPicker(Arr& array) : mData{array}
+    {
+        constexpr index_t imax =
+            accumulate_on_sequence(Picks{}, math::maxer{}, Number<0>{});
+
+        // every picked position must be a valid index into the underlying array
+        static_assert(imax < Arr::GetSize(), "wrong! exceeding max id");
+    }
+
+    __host__ __device__ static constexpr index_t GetSize() { return Picks::GetSize(); }
+
+    template
+    __host__ __device__ constexpr TData operator[](Number) const
+    {
+        constexpr auto IP = Picks::Get(Number{});
+        return mData[IP];
+    }
+
+    __host__ __device__ constexpr TData operator[](index_t i) const
+    {
+        // i is a runtime value, so the looked-up position cannot be constexpr
+        const index_t ip = Picks{}[i];
+        return mData[ip];
+    }
+
+    template
+    __host__ __device__ TData& operator()(Number)
+    {
+        constexpr auto IP = Picks::Get(Number{});
+        return mData[IP];
+    }
+
+    __host__ __device__ TData& operator()(index_t i)
+    {
+        const index_t ip = Picks{}[i];
+        return mData[ip];
+    }
+
+    Arr& mData;
+};
+
+template
+__host__ __device__ constexpr auto pick_array_element(Arr& a, Picks)
+{
+    // C++14: no class template argument deduction, spell out the arguments
+    return ArrayElementPicker<Arr, Picks>{a};
+}
+
 template
 __host__ __device__ constexpr auto sequence2array(Sequence)
 {
diff --git a/composable_kernel/include/utility/Sequence.hpp b/composable_kernel/include/utility/Sequence.hpp
index 1d8467afb0..3abdceda16 100644
--- a/composable_kernel/include/utility/Sequence.hpp
+++ b/composable_kernel/include/utility/Sequence.hpp
@@ -6,6 +6,9 @@
 
 namespace ck {
 
+template
+struct static_for;
+
 template
 struct Sequence;
 
@@ -294,6 +297,18 @@ struct sequence_reverse>
     using type = Sequence;
 };
 
+template
+struct sequence_sort
+{
+    // not implemented
+};
+
+template
+struct sequence_unique_sort
+{
+    // not implemented
+};
+
 template
 struct is_valid_sequence_map
 {
@@ -486,6 +501,35 @@ __host__ __device__ constexpr auto inclusive_scan_sequence(Seq, Reduce, Number)
 {
     return reverse_inclusive_scan_sequence(Seq{}.Reverse(), Reduce{}, Number{}).Reverse();
 }
 
+template
+struct lambda_accumulate_on_sequence
+{
+    const Reduce& f;
+    index_t& result;
+
+    __host__ __device__ constexpr lambda_accumulate_on_sequence(const Reduce& f_, index_t& result_)
+        : f(f_), result(result_)
+    {
+    }
+
+    template
+    __host__ __device__ constexpr index_t operator()(IDim) const
+    {
+        return result = f(result, Seq::Get(IDim{}));
+    }
+};
+
+template
+__host__ __device__ constexpr index_t
+accumulate_on_sequence(Seq, Reduce f, Number /*initial_value*/)
+{
+    index_t result = Init;
+
+    static_for<0, Seq::mSize, 1>{}(lambda_accumulate_on_sequence(f, result));
+
+    return result;
+}
+
 template
 __host__ __device__ void print_Sequence(const char* s, Sequence)
 {
diff --git a/composable_kernel/include/utility/functional2.hpp b/composable_kernel/include/utility/functional2.hpp
index 289b9d9b3f..52e96b90f5 100644
--- a/composable_kernel/include/utility/functional2.hpp
+++ b/composable_kernel/include/utility/functional2.hpp
@@ -37,34 +37,5 @@ struct static_for
     }
 };
 
-template
-struct lambda_accumulate_on_sequence
-{
-    const Reduce& f;
-    index_t& result;
-
-    __host__ __device__ constexpr lambda_accumulate_on_sequence(const Reduce& f_, index_t& result_)
-        : f(f_), result(result_)
-    {
-    }
-
-    template
-    __host__ __device__ constexpr index_t operator()(IDim) const
-    {
-        return result = f(result, Seq::Get(IDim{}));
-    }
-};
-
-template
-__host__ __device__ constexpr index_t
-accumulate_on_sequence(Seq, Reduce f, Number /*initial_value*/)
-{
-    index_t result = Init;
-
-    static_for<0, Seq::mSize, 1>{}(lambda_accumulate_on_sequence(f, result));
-
-    return result;
-}
-
 } // namespace ck
 #endif
diff --git a/composable_kernel/include/utility/math.hpp b/composable_kernel/include/utility/math.hpp
index 197759ad25..9e987df11f 100644
--- a/composable_kernel/include/utility/math.hpp
+++ b/composable_kernel/include/utility/math.hpp
@@ -31,6 +31,12 @@ struct multiplies
     __host__ __device__ constexpr T operator()(T a, T b) const { return a * b; }
 };
 
+template
+struct maxer
+{
+    __host__ __device__ constexpr T operator()(T a, T b) const { return a >= b ? a : b; }
+};
+
 template
 struct integer_divide_ceiler
 {