rework sequence

This commit is contained in:
Chao Liu
2019-05-18 23:21:02 -05:00
parent df73287b82
commit a6b95c393b
3 changed files with 167 additions and 229 deletions

View File

@@ -1,30 +1,11 @@
#pragma once
#include "common.hip.hpp"
template <class PreviousStrides, class RemainLengths>
__host__ __device__ constexpr auto calculate_default_strides_impl(PreviousStrides, RemainLengths)
{
constexpr index_t previous_stride = PreviousStrides{}.Front();
constexpr index_t current_length = RemainLengths{}.Back();
constexpr index_t current_stride = current_length * previous_stride;
return calculate_default_strides_impl(PreviousStrides{}.PushFront(Number<current_stride>{}),
RemainLengths{}.PopBack());
}
template <class PreviousStrides, index_t L0, index_t L1>
__host__ __device__ constexpr auto calculate_default_strides_impl(PreviousStrides, Sequence<L0, L1>)
{
constexpr index_t previous_stride = PreviousStrides{}.Front();
constexpr index_t current_stride = L1 * previous_stride;
return PreviousStrides{}.PushFront(Number<current_stride>{});
}
template <class Lengths>
__host__ __device__ constexpr auto calculate_default_strides(Lengths)
{
return calculate_default_strides_impl(Sequence<1>{}, Lengths{});
return reverse_inclusive_scan_sequence(Lengths{}.PopFront().PushBack(Number<1>{}),
std::multiplies<index_t>{});
}
// this is ugly, only for 2d
@@ -57,7 +38,8 @@ __host__ __device__ constexpr auto calculate_default_strides_aligned(Sequence<L0
template <class Lengths, class Strides>
struct ConstantTensorDescriptor
{
using Type = ConstantTensorDescriptor;
using Type = ConstantTensorDescriptor;
static constexpr index_t nDim = Lengths::GetSize();
__host__ __device__ constexpr ConstantTensorDescriptor()
@@ -193,7 +175,8 @@ struct ConstantTensorDescriptor
// folded strides
constexpr auto fold_strides =
Number<unfold_stride>{} *
reverse_scan_sequence(fold_intervals.PushBack(Number<1>{}), std::multiplies<index_t>{});
reverse_inclusive_scan_sequence(fold_intervals.PushBack(Number<1>{}),
std::multiplies<index_t>{});
// left and right
constexpr auto left = make_increasing_sequence(Number<0>{}, Number<IDim>{}, Number<1>{});