diff --git a/composable_kernel/include/tensor_description/tensor_descriptor.hpp b/composable_kernel/include/tensor_description/tensor_descriptor.hpp index df625c6ecb..4cff9a45d1 100644 --- a/composable_kernel/include/tensor_description/tensor_descriptor.hpp +++ b/composable_kernel/include/tensor_description/tensor_descriptor.hpp @@ -256,7 +256,7 @@ struct TransformedTensorDescriptor constexpr auto sorted2unsorted_map = typename sort_up_dimension_ids::sorted2unsorted_map{}; constexpr auto sorted_up_lengths = - pick_sequence_elements(mingled_up_lengths, sorted2unsorted_map); + pick_sequence_elements_by_ids(mingled_up_lengths, sorted2unsorted_map); return sorted_up_lengths; } @@ -347,20 +347,60 @@ struct TransformedTensorDescriptor } #if 0 + struct lambda_sequence_logic_or + { + template + __host__ __device__ constexpr auto operator()(Seqs... seqs) const + { + // TODO: should use math::logic_or, after Sequence can take bool + return typename sequence_reduce, Seqs...>::type{}; + } + }; + + struct lambda_1 + { + template + __host__ __device__ constexpr auto operator()(const Transform& tran) const + { + return tran.GetUpperLengths(); + } + }; + + template + __host__ __device__ static constexpr bool GetMaskOfLinearDimensions() + { + // create tuple of linear dimension masks, for all transformations + constexpr auto tuple_of_linear_dimension_mask = + transform_tuple(lambda_1, Transforms{}); + + // reduce tuple of masks into one mask + constexpr auto linear_dimension_mask = + unpack(lambda_sequence_logic_or{}, tuple_of_linear_dimension_mask); + + return linear_dimension_mask; + } + template __host__ __device__ static constexpr bool IsLinearDimension(Number) { - // not implemented + return GetMaskOfLinearDimensions().At(Number{}); } __host__ __device__ static constexpr auto GetLinearDimensions() { - // not implemented + constexpr auto linear_dimension_mask = GetMaskOfLienarDimensions(); + + return pick_sequence_elements_by_mask( + typename arithmetic_sequence_gen<0, nDimUp, 1>::type{}, linear_dimension_mask); } __host__ __device__ static constexpr auto GetNonLinearDimensions() { - // not implemented + constexpr auto nonlinear_dimension_mask = + GetMaskOfLienarDimensions().Transform(math::logic_not{}); + + return pick_sequence_elements_by_mask( + typename arithmetic_sequence_gen<0, nDimUp, 1>::type{}, nonlinear_dimension_mask); } __host__ __device__ static constexpr auto GetNonLinearIndependentDimensionGroups() diff --git a/composable_kernel/include/utility/sequence.hpp b/composable_kernel/include/utility/sequence.hpp index 37754cca20..c351754140 100644 --- a/composable_kernel/include/utility/sequence.hpp +++ b/composable_kernel/include/utility/sequence.hpp @@ -311,6 +311,28 @@ struct sequence_reverse> using type = Sequence; }; +#if 0 +template +struct sequence_reduce +{ + using type = typename sequence_reduce::type>::type; +}; + +template +struct sequence_reduce, Sequence> +{ + using type = Sequence; +}; + +template +struct sequence_reduce +{ + using type = Seq; +}; +#endif + template struct sequence_sort_impl { @@ -728,11 +750,19 @@ __host__ __device__ constexpr auto inclusive_scan_sequence(Seq, Reduce, Number -__host__ __device__ constexpr auto pick_sequence_elements(Seq, Sequence) +__host__ __device__ constexpr auto pick_sequence_elements_by_ids(Seq, Sequence /* ids */) { return Sequence{})...>{}; } +#if 0 +template +__host__ __device__ constexpr auto pick_sequence_elements_by_mask(Seq, Mask) +{ + // not implemented +} +#endif + template struct lambda_accumulate_on_sequence {