try using more constexpr

This commit is contained in:
Chao Liu
2019-06-04 17:02:49 -05:00
parent 917d7a2b1d
commit 498e71b098
10 changed files with 272 additions and 42 deletions

View File

@@ -4,7 +4,8 @@
template <class Lengths>
__host__ __device__ constexpr auto calculate_tensor_strides_packed(Lengths)
{
return reverse_inclusive_scan_sequence(Lengths{}.PopFront(), mod_conv::multiplies<index_t>{})
return reverse_inclusive_scan_sequence(
Lengths{}.PopFront(), mod_conv::multiplies<index_t>{}, Number<1>{})
.PushBack(Number<1>{});
}
@@ -91,8 +92,10 @@ struct ConstantTensorDescriptor
return align.Get() * ((element_space_unaligned + align.Get() - 1) / align.Get());
}
#if 0
template <index_t NSize>
__host__ __device__ static index_t GetOffsetFromMultiIndex(Array<index_t, NSize> multi_id)
__host__ __device__ static constexpr index_t
GetOffsetFromMultiIndex(Array<index_t, NSize> multi_id)
{
static_assert(NSize == nDim, "wrong! Dimension not consistent");
@@ -105,9 +108,43 @@ struct ConstantTensorDescriptor
return offset;
}
#else
template <index_t NSize>
struct GetOffsetFromMultiIndex_impl
{
Array<index_t, NSize>& multi_id_ref;
index_t& offset_ref;
__host__ __device__ constexpr GetOffsetFromMultiIndex_impl(Array<index_t, NSize>& multi_id,
index_t& offset)
: multi_id_ref(multi_id), offset_ref(offset)
{
}
template <index_t IDim>
__host__ __device__ constexpr bool operator()(Number<IDim>) const
{
offset_ref += multi_id_ref.Get(Number<IDim>{}) * Type::GetStride(Number<IDim>{});
return true;
}
};
template <index_t NSize>
__host__ __device__ static constexpr index_t
GetOffsetFromMultiIndex(Array<index_t, NSize> multi_id)
{
static_assert(NSize == nDim, "wrong! Dimension not consistent");
index_t offset = 0;
static_for<0, nDim, 1>{}(GetOffsetFromMultiIndex_impl<NSize>(multi_id, offset));
return offset;
}
#endif
template <class... Is>
__host__ __device__ static index_t GetOffsetFromMultiIndex(Is... is)
__host__ __device__ static constexpr index_t GetOffsetFromMultiIndex(Is... is)
{
return GetOffsetFromMultiIndex(Array<index_t, sizeof...(Is)>{is...});
}
@@ -123,7 +160,8 @@ struct ConstantTensorDescriptor
multi_id * GetStrides(), mod_conv::plus<index_t>{}, Number<0>{});
}
__host__ __device__ static Array<index_t, nDim> GetMultiIndexFrom1dIndex(index_t id)
#if 0
__host__ __device__ static constexpr Array<index_t, nDim> GetMultiIndexFrom1dIndex(index_t id)
{
Array<index_t, nDim> multi_id;
@@ -141,8 +179,58 @@ struct ConstantTensorDescriptor
return multi_id;
}
#else
struct GetMultiIndexFrom1dIndex_impl
{
using DummyStrides = decltype(calculate_tensor_strides_packed(GetLengths()));
__host__ __device__ static auto
index_t& id_ref;
Array<index_t, nDim>& multi_id_ref;
__host__ __device__ constexpr GetMultiIndexFrom1dIndex_impl(index_t& id,
Array<index_t, nDim>& multi_id)
: id_ref(id), multi_id_ref(multi_id)
{
}
template <index_t IDim>
__host__ __device__ constexpr bool operator()(Number<IDim>) const
{
constexpr index_t stride = DummyStrides::Get(Number<IDim>{});
multi_id_ref.Set(Number<IDim>{}, id_ref / stride);
id_ref -= multi_id_ref.Get(Number<IDim>{}) * stride;
return true;
}
};
__host__ __device__ static constexpr Array<index_t, nDim> GetMultiIndexFrom1dIndex(index_t id)
{
Array<index_t, nDim> multi_id;
constexpr auto dummy_strides = calculate_tensor_strides_packed(GetLengths());
// calculate index in each of the dimensions in the order of their dimension
static_for<0, nDim - 1, 1>{}(GetMultiIndexFrom1dIndex_impl(id, multi_id));
index_t itmp = id / dummy_strides.Get(Number<nDim - 1>{});
multi_id.Set(Number<nDim - 1>{}, itmp);
return multi_id;
}
#endif
#if 0
// return type is Sequence<...>
template<index_t Id>
__host__ __device__ static constexpr auto GetMultiIndexFrom1dIndex(Number<Id>)
{
return inclusive_scan_sequence(f_impl, GetStrides(), Number<Id>{});
}
#endif
__host__ __device__ static constexpr auto
GetOriginalMultiIndexFromMultiIndex(Array<index_t, nDim> multi_id)
{
return multi_id;
@@ -278,8 +366,8 @@ struct ConstantTensorDescriptor
// folded strides
constexpr auto fold_strides =
Number<unfold_stride>{} *
reverse_inclusive_scan_sequence(fold_intervals.PushBack(Number<1>{}),
mod_conv::multiplies<index_t>{});
reverse_inclusive_scan_sequence(
fold_intervals.PushBack(Number<1>{}), mod_conv::multiplies<index_t>{}, Number<1>{});
// left and right
constexpr auto left = typename arithmetic_sequence_gen<0, IDim, 1>::SeqType{};