mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-18 12:00:07 +00:00
@@ -85,24 +85,35 @@ struct ConstantTensorDescriptor
|
||||
|
||||
__host__ __device__ static constexpr index_t GetElementSize()
|
||||
{
|
||||
return accumulate_on_sequence(Lengths{}, mod_conv::multiplies<index_t>{}, Number<1>{});
|
||||
return accumulate_on_sequence(Lengths{}, std::multiplies<index_t>{}, Number<1>{});
|
||||
}
|
||||
|
||||
#if 0
|
||||
// c++14 doesn't support constexpr lambdas, has to use this trick instead
|
||||
struct GetElementSpace_f
|
||||
struct f_GetElementSpace_impl
|
||||
{
|
||||
template <class IDim>
|
||||
__host__ __device__ constexpr index_t operator()(IDim idim) const
|
||||
{
|
||||
return (Type{}.GetLength(idim) - 1) * Type{}.GetStride(idim);
|
||||
}
|
||||
__host__ __device__ constexpr index_t operator()(index_t length, index_t stride) const
|
||||
{
|
||||
return (length - 1) * stride;
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
template <class Align = Number<1>>
|
||||
__host__ __device__ static constexpr index_t GetElementSpace(Align align = Align{})
|
||||
{
|
||||
#if 0
|
||||
index_t element_space_unaligned =
|
||||
static_const_reduce_n<nDim>{}(GetElementSpace_f{}, mod_conv::plus<index_t>{}) + 1;
|
||||
static_const_reduce_n<nDim>{}(f_GetElementSpace_impl{}, std::plus<index_t>{}) + 1;
|
||||
#else
|
||||
constexpr index_t element_space_unaligned = accumulate_on_sequence(
|
||||
(GetLengths() - Number<1>{}) * GetStrides(), std::plus<index_t>{}, Number<1>{});
|
||||
#endif
|
||||
|
||||
return align.Get() * ((element_space_unaligned + align.Get() - 1) / align.Get());
|
||||
}
|
||||
@@ -140,9 +151,9 @@ struct ConstantTensorDescriptor
|
||||
constexpr auto multi_id = Sequence<Is...>{};
|
||||
|
||||
constexpr auto seq_tmp =
|
||||
transform_sequences(mod_conv::multiplies<index_t>{}, multi_id, GetStrides());
|
||||
transform_sequences(std::multiplies<index_t>{}, multi_id, GetStrides());
|
||||
|
||||
return accumulate_on_sequence(seq_tmp, mod_conv::plus<index_t>{}, Number<0>{});
|
||||
return accumulate_on_sequence(seq_tmp, std::plus<index_t>{}, Number<0>{});
|
||||
}
|
||||
|
||||
__host__ __device__ static Array<index_t, nDim> GetMultiIndex(index_t id)
|
||||
@@ -167,34 +178,112 @@ struct ConstantTensorDescriptor
|
||||
}
|
||||
|
||||
template <index_t IDims...>
|
||||
__host__ __device__ static constexpr auto Extract(Number<IDims>... /*extracted_dims...*/)
|
||||
__host__ __device__ static constexpr auto Extract(Number<IDims>... extract_dims)
|
||||
{
|
||||
static_assert(sizeof...(IDims) <= GetNumOfDimension(), "wrong!");
|
||||
static_assert(sizeof...(IDims) <= GetNumOfDimension(),
|
||||
"wrong! too many number of dimensions to be extracted");
|
||||
|
||||
constexpr auto extracted_lengths = Sequence<Lengths{}.Get(Number<IDims>{})...>{};
|
||||
constexpr auto extracted_strides = Sequence<Strides{}.Get(Number<IDims>{})...>{};
|
||||
|
||||
return make_ConstantTensorDescriptor(extracted_lenghts, extracted_strides);
|
||||
return make_ConstantTensorDescriptor(Lengths{}.Extract(extract_dims),
|
||||
Strides{}.Extract(extract_dims));
|
||||
}
|
||||
|
||||
template <index_t IDim, index_t SliceLen>
|
||||
__host__ __device__ static constexpr auto Slice(Number<IDim>, Number<SliceLen>)
|
||||
{
|
||||
// not implemented
|
||||
return make_ConstantTensorDescriptor(Lengths{}.Modify(Number<IDim>{}, Number<SliceLen>{}),
|
||||
Strides{});
|
||||
}
|
||||
|
||||
template <index_t IDim, index_t... FoldLengths>
|
||||
__host__ device__ static constexpr auto Fold(Number<IDim>, Sequence<FoldLengths...>)
|
||||
template <index_t IDim, index_t... FoldIntervals>
|
||||
__host__ device__ static constexpr auto Fold(Number<IDim>, Number<FoldIntervals>...)
|
||||
{
|
||||
// not implemented
|
||||
// need to check the Length dimension to be folded is dividable by FoldLengths
|
||||
constexpr auto fold_intervals = Sequence<FoldIntervals...>{};
|
||||
|
||||
constexpr fold_intervals_product =
|
||||
accumulate_on_sequence(fold_intervals, std::multiplies<index_t>{}, Number<1>{});
|
||||
|
||||
constexpr auto unfold_length = GetLength(Number<IDim>{});
|
||||
constexpr auto unfold_stride = GetStride(Number<IDim>{});
|
||||
|
||||
// length of the dimension to be folded needs to be dividable by fold_interval_product,
|
||||
// otherwise, folding is invalid
|
||||
static_assert(unfold_length % fold_interval_product == 0,
|
||||
"wrong! length on the dimension to be folded cannot be evenly divided!");
|
||||
|
||||
// folded lengths
|
||||
constexpr auto fold_lengths =
|
||||
Sequence<unfold_length / fold_interval_product>{}.Append(fold_intervals);
|
||||
|
||||
// folded strides
|
||||
constexpr auto fold_strides = transform_sequences(mod_conv::scales<index_t, unfold_stride>{},
|
||||
reverse_scan_sequence(fold_intervals.PushBack(Number<1>{}), std::multiplies<index_t>{});
|
||||
|
||||
// left and right lengths
|
||||
constexpr auto lengths_pair = GetLengths().Split(Number<I>{});
|
||||
constexpr auto left_lengths = lengths_pair.first;
|
||||
constexpr auto right_lengths = lengths_pair.second.PopFront();
|
||||
|
||||
// left and right strides
|
||||
constexpr auto strides_pair = GetStrides().Split(Number<I>{});
|
||||
constexpr auto left_strides = strides_pair.first;
|
||||
constexpr auto right_strides = strides_pair.second.PopFront();
|
||||
|
||||
return make_ConstantTensorDescriptor(left_lengths.Append(fold_lengths).Append(right_lengths),
|
||||
left_strides.Append(fold_strides).Append(right_strides));
|
||||
}
|
||||
|
||||
template <index_t FirstUnfoldDim, index_t LastUnfoldDim>
|
||||
__host__ __device__ static constexpr auto Unfold(Number<FirstUnfoldDim>, Number<LastUnfoldDim>)
|
||||
{
|
||||
// not implemented
|
||||
// need to check the dimensions to be unfold are packed, otherwise, Unfold is not permitted
|
||||
static_assert(FirstUnfoldDim >= 0 && LastUnfoldDim < nDim &&
|
||||
FirstUnfoldDim <= LastUnfoldDim,
|
||||
"wrong! should have FirstUnfoldDim <= LastUnfoldDim!");
|
||||
|
||||
// dimensions to be unfold need to be in descending order (w.r.t. strides), and need to be
|
||||
// packed in memory, otherwise, unfolding is invalid
|
||||
static_for<FirstUnfoldDim, LastUnfoldDim, 1>{}([&](auto IDim) {
|
||||
static_assert(
|
||||
GetStride(IDim) >= GetStride(Number<IDim.Get() + 1>{}),
|
||||
"wrong! dimensions to be unfolded need to be in descending order w.r.t strides");
|
||||
|
||||
static_assert(GetStride(IDim + 1) * GetLength(IDim + 1) == GetStride(IDim),
|
||||
"wrong! dimensions to be unfolded need to be packed");
|
||||
});
|
||||
|
||||
// lengths
|
||||
constexpr auto lens_pair1 = Lengths{}.Split(Number<LastUnfoldDim + 1>{});
|
||||
|
||||
constexpr auto right_lengths = lens_pair1.second;
|
||||
|
||||
constexpr auto lens_pair2 = lens_pair1.first.Split(Number<FirstUnfoldDim>{});
|
||||
|
||||
constexpr auto left_lengths = lens_pair2.first;
|
||||
|
||||
constexpr auto fold_lengths = lens_pair2.second;
|
||||
|
||||
constexpr index_t unfold_length =
|
||||
accumulate_on_sequence(fold_lengths, std::multiplies<index_t>{}, Number<1>{});
|
||||
|
||||
constexpr auto new_strides =
|
||||
left_strides.PopBack(Number<unfold_strides>{}).Append(right_strides);
|
||||
|
||||
// strides
|
||||
constexpr auto strides_pair1 = Strides{}.Split(Number<LastUnfoldDim + 1>{});
|
||||
|
||||
constexpr auto right_strides = strides_pair1.second;
|
||||
|
||||
constexpr auto strides_pair2 = strides_pair1.first.Split(Number<FirstUnfoldDim>{});
|
||||
|
||||
constexpr auto left_strides = strides_pair2.first;
|
||||
|
||||
constexpr auto fold_strides = strides_pair2.second;
|
||||
|
||||
constexpr index_t unfold_stride = fold_strides.Back();
|
||||
|
||||
constexpr auto new_strides =
|
||||
left_strides.PushBack(Number<unfold_strides>{}).Append(right_strides);
|
||||
|
||||
return make_ConstantTensorDescriptor(new_lengths, new_strides);
|
||||
}
|
||||
|
||||
template <index_t... IRs>
|
||||
|
||||
Reference in New Issue
Block a user