mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 01:10:17 +00:00
rework sequence
This commit is contained in:
@@ -1,30 +1,11 @@
|
||||
#pragma once
|
||||
#include "common.hip.hpp"
|
||||
|
||||
template <class PreviousStrides, class RemainLengths>
|
||||
__host__ __device__ constexpr auto calculate_default_strides_impl(PreviousStrides, RemainLengths)
|
||||
{
|
||||
constexpr index_t previous_stride = PreviousStrides{}.Front();
|
||||
constexpr index_t current_length = RemainLengths{}.Back();
|
||||
constexpr index_t current_stride = current_length * previous_stride;
|
||||
|
||||
return calculate_default_strides_impl(PreviousStrides{}.PushFront(Number<current_stride>{}),
|
||||
RemainLengths{}.PopBack());
|
||||
}
|
||||
|
||||
template <class PreviousStrides, index_t L0, index_t L1>
|
||||
__host__ __device__ constexpr auto calculate_default_strides_impl(PreviousStrides, Sequence<L0, L1>)
|
||||
{
|
||||
constexpr index_t previous_stride = PreviousStrides{}.Front();
|
||||
constexpr index_t current_stride = L1 * previous_stride;
|
||||
|
||||
return PreviousStrides{}.PushFront(Number<current_stride>{});
|
||||
}
|
||||
|
||||
template <class Lengths>
|
||||
__host__ __device__ constexpr auto calculate_default_strides(Lengths)
|
||||
{
|
||||
return calculate_default_strides_impl(Sequence<1>{}, Lengths{});
|
||||
return reverse_inclusive_scan_sequence(Lengths{}.PopFront().PushBack(Number<1>{}),
|
||||
std::multiplies<index_t>{});
|
||||
}
|
||||
|
||||
// this is ugly, only for 2d
|
||||
@@ -57,7 +38,8 @@ __host__ __device__ constexpr auto calculate_default_strides_aligned(Sequence<L0
|
||||
template <class Lengths, class Strides>
|
||||
struct ConstantTensorDescriptor
|
||||
{
|
||||
using Type = ConstantTensorDescriptor;
|
||||
using Type = ConstantTensorDescriptor;
|
||||
|
||||
static constexpr index_t nDim = Lengths::GetSize();
|
||||
|
||||
__host__ __device__ constexpr ConstantTensorDescriptor()
|
||||
@@ -193,7 +175,8 @@ struct ConstantTensorDescriptor
|
||||
// folded strides
|
||||
constexpr auto fold_strides =
|
||||
Number<unfold_stride>{} *
|
||||
reverse_scan_sequence(fold_intervals.PushBack(Number<1>{}), std::multiplies<index_t>{});
|
||||
reverse_inclusive_scan_sequence(fold_intervals.PushBack(Number<1>{}),
|
||||
std::multiplies<index_t>{});
|
||||
|
||||
// left and right
|
||||
constexpr auto left = make_increasing_sequence(Number<0>{}, Number<IDim>{}, Number<1>{});
|
||||
|
||||
Reference in New Issue
Block a user