adding ConstantMergedTensorDescriptor, refactering ConstantTensorDescriptor, Sequence

This commit is contained in:
Chao Liu
2019-05-21 16:17:58 -05:00
parent cd29b09a82
commit acd7082fe1
38 changed files with 1238 additions and 768 deletions

View File

@@ -2,94 +2,118 @@
#include "common.hip.hpp"
#include "ConstantTensorDescriptor.hip.hpp"
// TensorDesc: ConstantTensorDescriptor<...>
// MergedDimRanges: Sequence<FirstMergedDim, LastMergedDim>
template <class TensorDesc, class... MergedDimRanges>
// OriginalTensorDesc : ConstantTensorDescriptor<...>
// it's the tensor whose dimensions are to be merged
// OriginalDimMergeSeqs : Sequence<...>...
// each is a sequence of original dimensions (of OriginalTensorDesc) to be merged
template <class OriginalTensorDesc, class... OriginalDimMergeSeqs>
struct ConstantMergedTensorDescriptor
{
static constexpr index_t nOriginalDim = GetNumOfOriginalDimension();
static constexpr index_t nDim = GetNumOfDimension();
static constexpr auto mOriginalDimMergeSeqs = std::tuple<OriginalDimMergeSeqs...>{};
static constexpr index_t nDim = std::tuple_size<mOriginalDimMergeSeqs>::value;
static constexpr index_t nOriginalDim = OriginalDesc::GetNumOfDimension();
template <class... Is>
__host__ __device__ constexpr ConstantMergedTensorDescriptor()
{
constexpr auto merged_dim_ranges = std::make_tuple(MergedDimRanges{}...);
static_assert(nDim <= nOriginalDim, "wrong!");
static_for<0, sizeof...(MergedDimRanges), 1>{}([&](auto I) {
constexpr index_t i = I.Get();
constexpr auto merged_dim_range = std::get<i>(merged_dim_ranges);
// TODO: check each of OriginalDimMergeSeqs contains at least 1, and at most
// OriginalTensorDesc::nDim number of dimensions
static_assert(merged_dim_range.GetSize() == 2,
"wrong! should specify first and last dimension to be merged");
static_assert(merged_dim_range.Get(Number<0>{}) < GetNumOfUnmergedDimension(),
"wrong!");
static_assert(merged_dim_range.Get(Number<1>{}) < GetNumOfUnmergedDimension(),
"wrong!");
static_assert(merged_dim_range.Get(Number<0>{}) <= merged_dim_range.Get(Number<1>{}),
"wrong!");
});
// TODO: check there is no duplication in OriginalDimMergeSeqs
// TODO: check OriginalDimMergeSeqs contains all original dimensions
}
__host__ __device__ static constexpr index_t GetNumOfDimension()
{
constexpr auto merged_dim_ranges = std::make_tuple(MergedDimRanges...);
__host__ __device__ static constexpr index_t GetNumOfDimension() { return nDim; }
struct f_calculate_num_of_lost_dim
{
__host__ __device__ constexpr index_t operator()(auto I) const
__host__ __device__ static constexpr index_t GetNumOfOriginalDimension() { return nOriginalDim }
template <index_t IDim>
__host__ __device__ static constexpr bool ContainMultipleOriginalDimensions(Number<IDim>)
{
return (std::Get<IDIM>(mOriginalDimMergeSeqs).GetSize() > 1);
}
template <index_t IDim>
__host__ __device__ static constexpr index_t GetLength(Number<IDim>)
{
constexpr auto original_dims_partial = std::Get<IDim>(mOriginalDimMergeSeqs);
return OriginalTensorDesc::Extract(original_dims_partial).GetElementSize();
}
template <index_t IDim>
__host__ __device__ static constexpr index_t GetStride(Number<IDim>)
{
static_assert(!ContainMultipleOriginalDimensions(Number<IDim>{}),
"wrong! stride of a merged dimension is undefined");
constexpr auto idim_original = std::Get<IDim>(mOriginalDimMergeSeqs).Front();
return OriginalTensorDesc::GetStride(Number<idim_original>{});
}
__host__ __device__ static constexpr auto GetLengths()
{
return Sequence<OriginalTensorDesc::Extract(OriginalDimMergeSeqs).GetElementSize()...>{};
}
__host__ __device__ static constexpr index_t GetElementSize()
{
return OriginalTensorDesc::GetElementSize();
}
__host__ __device__ static auto
GetOriginalMultiIndexFromMultiIndex(Array<index_t, nDim> multi_id)
{
Array<index_t, nOriginalDim> original_multi_id;
static_for<0, nDim, 1>{}([&](auto IDim) {
constexpr index_t idim = IDim.Get();
constexpr auto original_dims_partial = std::get<idim>(mOriginalDimMergeSeqs);
// get partial original-multi-id corresponding to this merged dimension
constexpr auto original_multi_id_partial =
OriginalTensorDesc::Extract(original_dims_partial)
.GetMultiIndexFrom1dIndex(multi_id[idim]);
// make sure compiler unroll this loop and propagate all the constants
for(index_t i = 0; i < original_dims_partial.GetSize(); ++i)
{
constexpr index_t i = I.Get();
constexpr auto merged_dim_range = std::get<i>(merged_dim_ranges);
index_t idim_original = original_dims_partial[i];
return merged_dim_range.Get(Number<1>{}) - merged_dim_range.Get(Number<0>{});
original_multi_id[idim_original] = original_multi_id_partial[i]
}
};
});
constexpr index_t num_lost_dim = static_const_reduce_n<sizeof...(MergedDimRanges)>{}(
f_calculate_num_of_lost_dim, std::plus<index_t>{});
return TensorDesc::GetNumOfDimension() - num_lost_dim;
return original_multi_id;
}
__host__ __device__ static constexpr index_t GetNumOfOriginalDimension()
__host__ __device__ static index_t GetOffsetFromMultiIndex(Array<index_t, nDim> multi_id)
{
return TensorDesc::GetNumOfDimension();
const auto original_multi_id = GetOriginalMultiIndexFromMultiIndex(multi_id);
return OriginalTensorDesc::GetOffsetFromMultiIndex(orginal_multi_id);
}
template <index_t IDim>
__host__ __device__ static constexpr bool IsMergedDimension(Number<IDim>)
template <index_t... Is>
__host__ __device__ static index_t GetOffsetFromMultiIndex(Is... is)
{
// not implemented
return GetOffsetFromMultiIndex(Array<index_t, nDim>{is...});
}
template <index_t IDim>
__host__ __device__ static constexpr bool GetLength(Number<IDim>)
__host__ __device__ static Array<index_t, nDim> GetMultiIndexFrom1dIndex(index_t id)
{
// not implemented
}
constexpr auto dummy_desc = make_packed_ConstantTensorDescriptor(GetLengths());
template <index_t IDim>
__host__ __device__ static constexpr bool GetStride(Number<IDim>)
{
static_assert(!IsMergedDimension(Number<IDim>{}, "wrong! stride of a merged dimension is undefined")
// not implemented
}
template <class... Is>
__host__ __device__ auto MultiIndex2OriginalMultiIndex(Is... is) const
{
// not implemented
}
template <class... Is>
__host__ __device__ auto OriginalMultiIndex2MultiIndex(Is... is) const
{
// not implemented
return dummy_desc.GetMultiIndexFrom1dIndex(id);
}
};
template <class TensorDesc, class... MergedDimRanges>
constexpr auto make_ConstantMergedTensorDescriptor(TensorDesc, MergedDimRanges...)
template <class OriginalTensorDesc, class... OriginalDimMergeSeqs>
constexpr auto make_ConstantMergedTensorDescriptor(OriginalTensorDesc, OriginalDimMergeSeqs...)
{
return ConstantMergedTensorDescriptor<TensorDesc, MergedDimRanges...>{};
return ConstantMergedTensorDescriptor<OriginalTensorDesc, OriginalDimMergeSeqs...>{};
}