mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 17:26:00 +00:00
adding ConstantMergedTensorDescriptor, refactoring ConstantTensorDescriptor, Sequence
This commit is contained in:
@@ -2,94 +2,118 @@
|
||||
#include "common.hip.hpp"
|
||||
#include "ConstantTensorDescriptor.hip.hpp"
|
||||
|
||||
// TensorDesc: ConstantTensorDescriptor<...>
|
||||
// MergedDimRanges: Sequence<FirstMergedDim, LastMergedDim>
|
||||
template <class TensorDesc, class... MergedDimRanges>
|
||||
// OriginalTensorDesc : ConstantTensorDescriptor<...>
|
||||
// it's the tensor whose dimensions are to be merged
|
||||
// OriginalDimMergeSeqs : Sequence<...>...
|
||||
// each is a sequence of original dimensions (of OriginalTensorDesc) to be merged
|
||||
template <class OriginalTensorDesc, class... OriginalDimMergeSeqs>
|
||||
struct ConstantMergedTensorDescriptor
|
||||
{
|
||||
static constexpr index_t nOriginalDim = GetNumOfOriginalDimension();
|
||||
static constexpr index_t nDim = GetNumOfDimension();
|
||||
static constexpr auto mOriginalDimMergeSeqs = std::tuple<OriginalDimMergeSeqs...>{};
|
||||
|
||||
static constexpr index_t nDim = std::tuple_size<mOriginalDimMergeSeqs>::value;
|
||||
static constexpr index_t nOriginalDim = OriginalDesc::GetNumOfDimension();
|
||||
|
||||
template <class... Is>
|
||||
__host__ __device__ constexpr ConstantMergedTensorDescriptor()
|
||||
{
|
||||
constexpr auto merged_dim_ranges = std::make_tuple(MergedDimRanges{}...);
|
||||
static_assert(nDim <= nOriginalDim, "wrong!");
|
||||
|
||||
static_for<0, sizeof...(MergedDimRanges), 1>{}([&](auto I) {
|
||||
constexpr index_t i = I.Get();
|
||||
constexpr auto merged_dim_range = std::get<i>(merged_dim_ranges);
|
||||
// TODO: check each of OriginalDimMergeSeqs contains at least 1, and at most
|
||||
// OriginalTensorDesc::nDim number of dimensions
|
||||
|
||||
static_assert(merged_dim_range.GetSize() == 2,
|
||||
"wrong! should specify first and last dimension to be merged");
|
||||
static_assert(merged_dim_range.Get(Number<0>{}) < GetNumOfUnmergedDimension(),
|
||||
"wrong!");
|
||||
static_assert(merged_dim_range.Get(Number<1>{}) < GetNumOfUnmergedDimension(),
|
||||
"wrong!");
|
||||
static_assert(merged_dim_range.Get(Number<0>{}) <= merged_dim_range.Get(Number<1>{}),
|
||||
"wrong!");
|
||||
});
|
||||
// TODO: check there is no duplication in OriginalDimMergeSeqs
|
||||
|
||||
// TODO: check OriginalDimMergeSeqs contains all original dimensions
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr index_t GetNumOfDimension()
|
||||
{
|
||||
constexpr auto merged_dim_ranges = std::make_tuple(MergedDimRanges...);
|
||||
__host__ __device__ static constexpr index_t GetNumOfDimension() { return nDim; }
|
||||
|
||||
struct f_calculate_num_of_lost_dim
|
||||
{
|
||||
__host__ __device__ constexpr index_t operator()(auto I) const
|
||||
__host__ __device__ static constexpr index_t GetNumOfOriginalDimension() { return nOriginalDim }
|
||||
|
||||
template <index_t IDim>
|
||||
__host__ __device__ static constexpr bool ContainMultipleOriginalDimensions(Number<IDim>)
|
||||
{
|
||||
return (std::Get<IDIM>(mOriginalDimMergeSeqs).GetSize() > 1);
|
||||
}
|
||||
|
||||
template <index_t IDim>
|
||||
__host__ __device__ static constexpr index_t GetLength(Number<IDim>)
|
||||
{
|
||||
constexpr auto original_dims_partial = std::Get<IDim>(mOriginalDimMergeSeqs);
|
||||
|
||||
return OriginalTensorDesc::Extract(original_dims_partial).GetElementSize();
|
||||
}
|
||||
|
||||
template <index_t IDim>
|
||||
__host__ __device__ static constexpr index_t GetStride(Number<IDim>)
|
||||
{
|
||||
static_assert(!ContainMultipleOriginalDimensions(Number<IDim>{}),
|
||||
"wrong! stride of a merged dimension is undefined");
|
||||
|
||||
constexpr auto idim_original = std::Get<IDim>(mOriginalDimMergeSeqs).Front();
|
||||
|
||||
return OriginalTensorDesc::GetStride(Number<idim_original>{});
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetLengths()
|
||||
{
|
||||
return Sequence<OriginalTensorDesc::Extract(OriginalDimMergeSeqs).GetElementSize()...>{};
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr index_t GetElementSize()
|
||||
{
|
||||
return OriginalTensorDesc::GetElementSize();
|
||||
}
|
||||
|
||||
__host__ __device__ static auto
|
||||
GetOriginalMultiIndexFromMultiIndex(Array<index_t, nDim> multi_id)
|
||||
{
|
||||
Array<index_t, nOriginalDim> original_multi_id;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto IDim) {
|
||||
constexpr index_t idim = IDim.Get();
|
||||
constexpr auto original_dims_partial = std::get<idim>(mOriginalDimMergeSeqs);
|
||||
|
||||
// get partial original-multi-id corresponding to this merged dimension
|
||||
constexpr auto original_multi_id_partial =
|
||||
OriginalTensorDesc::Extract(original_dims_partial)
|
||||
.GetMultiIndexFrom1dIndex(multi_id[idim]);
|
||||
|
||||
// make sure compiler unroll this loop and propagate all the constants
|
||||
for(index_t i = 0; i < original_dims_partial.GetSize(); ++i)
|
||||
{
|
||||
constexpr index_t i = I.Get();
|
||||
constexpr auto merged_dim_range = std::get<i>(merged_dim_ranges);
|
||||
index_t idim_original = original_dims_partial[i];
|
||||
|
||||
return merged_dim_range.Get(Number<1>{}) - merged_dim_range.Get(Number<0>{});
|
||||
original_multi_id[idim_original] = original_multi_id_partial[i]
|
||||
}
|
||||
};
|
||||
});
|
||||
|
||||
constexpr index_t num_lost_dim = static_const_reduce_n<sizeof...(MergedDimRanges)>{}(
|
||||
f_calculate_num_of_lost_dim, std::plus<index_t>{});
|
||||
|
||||
return TensorDesc::GetNumOfDimension() - num_lost_dim;
|
||||
return original_multi_id;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr index_t GetNumOfOriginalDimension()
|
||||
__host__ __device__ static index_t GetOffsetFromMultiIndex(Array<index_t, nDim> multi_id)
|
||||
{
|
||||
return TensorDesc::GetNumOfDimension();
|
||||
const auto original_multi_id = GetOriginalMultiIndexFromMultiIndex(multi_id);
|
||||
|
||||
return OriginalTensorDesc::GetOffsetFromMultiIndex(orginal_multi_id);
|
||||
}
|
||||
|
||||
template <index_t IDim>
|
||||
__host__ __device__ static constexpr bool IsMergedDimension(Number<IDim>)
|
||||
template <index_t... Is>
|
||||
__host__ __device__ static index_t GetOffsetFromMultiIndex(Is... is)
|
||||
{
|
||||
// not implemented
|
||||
return GetOffsetFromMultiIndex(Array<index_t, nDim>{is...});
|
||||
}
|
||||
|
||||
template <index_t IDim>
|
||||
__host__ __device__ static constexpr bool GetLength(Number<IDim>)
|
||||
__host__ __device__ static Array<index_t, nDim> GetMultiIndexFrom1dIndex(index_t id)
|
||||
{
|
||||
// not implemented
|
||||
}
|
||||
constexpr auto dummy_desc = make_packed_ConstantTensorDescriptor(GetLengths());
|
||||
|
||||
template <index_t IDim>
|
||||
__host__ __device__ static constexpr bool GetStride(Number<IDim>)
|
||||
{
|
||||
static_assert(!IsMergedDimension(Number<IDim>{}, "wrong! stride of a merged dimension is undefined")
|
||||
// not implemented
|
||||
}
|
||||
|
||||
template <class... Is>
|
||||
__host__ __device__ auto MultiIndex2OriginalMultiIndex(Is... is) const
|
||||
{
|
||||
// not implemented
|
||||
}
|
||||
|
||||
template <class... Is>
|
||||
__host__ __device__ auto OriginalMultiIndex2MultiIndex(Is... is) const
|
||||
{
|
||||
// not implemented
|
||||
return dummy_desc.GetMultiIndexFrom1dIndex(id);
|
||||
}
|
||||
};
|
||||
|
||||
template <class TensorDesc, class... MergedDimRanges>
|
||||
constexpr auto make_ConstantMergedTensorDescriptor(TensorDesc, MergedDimRanges...)
|
||||
template <class OriginalTensorDesc, class... OriginalDimMergeSeqs>
|
||||
constexpr auto make_ConstantMergedTensorDescriptor(OriginalTensorDesc, OriginalDimMergeSeqs...)
|
||||
{
|
||||
return ConstantMergedTensorDescriptor<TensorDesc, MergedDimRanges...>{};
|
||||
return ConstantMergedTensorDescriptor<OriginalTensorDesc, OriginalDimMergeSeqs...>{};
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user