mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 09:16:52 +00:00
adding logic to judge linear dimension
This commit is contained in:
@@ -256,7 +256,7 @@ struct TransformedTensorDescriptor
|
||||
constexpr auto sorted2unsorted_map = typename sort_up_dimension_ids::sorted2unsorted_map{};
|
||||
|
||||
constexpr auto sorted_up_lengths =
|
||||
pick_sequence_elements(mingled_up_lengths, sorted2unsorted_map);
|
||||
pick_sequence_elements_by_ids(mingled_up_lengths, sorted2unsorted_map);
|
||||
|
||||
return sorted_up_lengths;
|
||||
}
|
||||
@@ -347,20 +347,60 @@ struct TransformedTensorDescriptor
|
||||
}
|
||||
|
||||
#if 0
|
||||
struct lambda_sequence_logic_or
|
||||
{
|
||||
template <typename... Seqs>
|
||||
__host__ __device__ constexpr auto operator()(Seqs... seqs) const
|
||||
{
|
||||
// TODO: should use math::logic_or<bool>, after Sequence can take bool
|
||||
return typename sequence_reduce<math::logic_or<index_t>, Seqs...>::type{};
|
||||
}
|
||||
};
|
||||
|
||||
struct lambda_1
|
||||
{
|
||||
template <typename Transform>
|
||||
__host__ __device__ constexpr auto operator()(const Transform& tran) const
|
||||
{
|
||||
return tran.GetUpperLengths();
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t IDim>
|
||||
__host__ __device__ static constexpr bool GetMaskOfLinearDimensions()
|
||||
{
|
||||
// create tuple of linear dimension masks, for all transformations
|
||||
constexpr auto tuple_of_linear_dimension_mask =
|
||||
transform_tuple(lambda_1, Transforms{});
|
||||
|
||||
// reduce tuple of masks into one mask
|
||||
constexpr auto linear_dimension_mask =
|
||||
unpack(lambda_sequence_logic_or{}, tuple_of_linear_dimension_mask);
|
||||
|
||||
return linear_dimension_mask;
|
||||
}
|
||||
|
||||
template <index_t IDim>
|
||||
__host__ __device__ static constexpr bool IsLinearDimension(Number<IDim>)
|
||||
{
|
||||
// not implemented
|
||||
return GetMaskOfLinearDimensions().At(Number<IDim>{});
|
||||
}
|
||||
|
||||
// Returns the upper-dimension ids selected by the linear-dimension mask.
//
// The candidate ids come from arithmetic_sequence_gen<0, nDimUp, 1>
// (presumably 0, 1, ..., nDimUp-1 - confirm against its definition) and are
// filtered with pick_sequence_elements_by_mask.
// NOTE(review): pick_sequence_elements_by_mask must be implemented before this
// code can be enabled.
__host__ __device__ static constexpr auto GetLinearDimensions()
{
    // fixed spelling: the mask getter is GetMaskOfLinearDimensions
    constexpr auto linear_dimension_mask = GetMaskOfLinearDimensions();

    return pick_sequence_elements_by_mask(
        typename arithmetic_sequence_gen<0, nDimUp, 1>::type{}, linear_dimension_mask);
}
|
||||
|
||||
// Returns the upper-dimension ids selected by the complement of the
// linear-dimension mask.
//
// The linear mask is passed through Transform(math::logic_not<index_t>{})
// (presumably an element-wise negation - confirm against Sequence::Transform)
// and then used to filter the ids generated by
// arithmetic_sequence_gen<0, nDimUp, 1>.
// NOTE(review): pick_sequence_elements_by_mask must be implemented before this
// code can be enabled.
__host__ __device__ static constexpr auto GetNonLinearDimensions()
{
    // fixed spelling: the mask getter is GetMaskOfLinearDimensions
    constexpr auto nonlinear_dimension_mask =
        GetMaskOfLinearDimensions().Transform(math::logic_not<index_t>{});

    return pick_sequence_elements_by_mask(
        typename arithmetic_sequence_gen<0, nDimUp, 1>::type{}, nonlinear_dimension_mask);
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetNonLinearIndependentDimensionGroups()
|
||||
|
||||
@@ -311,6 +311,28 @@ struct sequence_reverse<Sequence<I0, I1>>
|
||||
using type = Sequence<I1, I0>;
|
||||
};
|
||||
|
||||
#if 0
|
||||
template <typename Reduce, typename Seq, typename... Seqs>
|
||||
struct sequence_reduce
|
||||
{
|
||||
using type = typename sequence_reduce<Reduce,
|
||||
Seq,
|
||||
typename sequence_reduce<Reduce, Seqs...>::type>::type;
|
||||
};
|
||||
|
||||
template <typename Reduce, index_t... Xs, index_t... Ys>
|
||||
struct sequence_reduce<Reduce, Sequence<Xs...>, Sequence<Ys...>>
|
||||
{
|
||||
using type = Sequence<Reduce{}(Xs, Ys)...>;
|
||||
};
|
||||
|
||||
template <typename Reduce, typename Seq>
|
||||
struct sequence_reduce<Reduce, Seq>
|
||||
{
|
||||
using type = Seq;
|
||||
};
|
||||
#endif
|
||||
|
||||
template <typename Values, typename Ids, typename Compare>
|
||||
struct sequence_sort_impl
|
||||
{
|
||||
@@ -728,11 +750,19 @@ __host__ __device__ constexpr auto inclusive_scan_sequence(Seq, Reduce, Number<I
|
||||
}
|
||||
|
||||
template <typename Seq, index_t... Is>
|
||||
__host__ __device__ constexpr auto pick_sequence_elements(Seq, Sequence<Is...>)
|
||||
__host__ __device__ constexpr auto pick_sequence_elements_by_ids(Seq, Sequence<Is...> /* ids */)
|
||||
{
|
||||
return Sequence<Seq::At(Number<Is>{})...>{};
|
||||
}
|
||||
|
||||
#if 0
|
||||
// Placeholder for a mask-based counterpart of pick_sequence_elements_by_ids.
// Presumably meant to keep the elements of Seq whose Mask entry is set - TODO confirm.
template <typename Seq, typename Mask>
|
||||
__host__ __device__ constexpr auto pick_sequence_elements_by_mask(Seq, Mask)
|
||||
{
|
||||
// TODO: not implemented - the body is empty, so nothing is returned yet
|
||||
}
|
||||
#endif
|
||||
|
||||
template <typename Seq, typename Reduce>
|
||||
struct lambda_accumulate_on_sequence
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user