mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 09:16:52 +00:00
adding dimension transformation
This commit is contained in:
@@ -1,217 +0,0 @@
|
||||
#ifndef CK_DIMENSION_TRANSFORM_HPP
|
||||
#define CK_DIMENSION_TRANSFORM_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <index_t N>
|
||||
using MultiIndex = Array<index_t, N>;
|
||||
|
||||
// LowLengths: Sequence<...>
|
||||
template <class LowLengths>
|
||||
struct PassThrough
|
||||
{
|
||||
static constexpr index_t nDim = LowLengths::GetSize();
|
||||
|
||||
using LowerId = MultiIndex<nDim>;
|
||||
using UpperId = LowerId;
|
||||
|
||||
__host__ __device__ static constexpr auto GetLowerNumOfDimension() { return Number<nDim>{}; }
|
||||
|
||||
__host__ __device__ static constexpr auto GetUpperNumOfDimension()
|
||||
{
|
||||
return GetLowerNumOfDimension();
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetLowerLengths() { return LowLengths{}; }
|
||||
|
||||
__host__ __device__ static constexpr auto GetUpperLengths() { return GetLowerLengths(); }
|
||||
|
||||
__host__ __device__ static constexpr auto GetLowerId(UpperId id_up) { return id_up; }
|
||||
|
||||
__host__ __device__ static constexpr auto GetLowerIdDiff(UpperId id_up_diff)
|
||||
{
|
||||
return id_up_diff;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
|
||||
};
|
||||
|
||||
// LowLengths: Sequence<...>
|
||||
template <class LowLengths, class LeftPads, class RightPads>
|
||||
struct Pad
|
||||
{
|
||||
static constexpr index_t nDim = LowLengths::GetSize();
|
||||
|
||||
using LowerId = MultiIndex<nDim>;
|
||||
using UpperId = LowerId;
|
||||
|
||||
__host__ __device__ static constexpr auto GetLowerNumOfDimension() { return Number<nDim>{}; }
|
||||
|
||||
__host__ __device__ static constexpr auto GetUpperNumOfDimension()
|
||||
{
|
||||
return GetLowerNumOfDimension();
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetLowerLengths() { return LowLengths{}; }
|
||||
|
||||
__host__ __device__ static constexpr auto GetUpperLengths()
|
||||
{
|
||||
return GetLowerLengths() + LeftPads + RightPads;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetLowerId(UpperId id_up) { return id_up - LeftPads; }
|
||||
|
||||
__host__ __device__ static constexpr auto GetLowerIdDiff(UpperId id_up_diff)
|
||||
{
|
||||
return id_up_diff;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
|
||||
};
|
||||
|
||||
// LowLengths: Sequence<...>
|
||||
template <class LowLengths>
|
||||
struct Merge
|
||||
{
|
||||
static constexpr index_t nDimLow = LowLengths::GetSize();
|
||||
static constexpr index_t nDimUp = 1;
|
||||
|
||||
using LowerId = MultiIndex<nDimLow>;
|
||||
using UpperId = MultiIndex<nDimUp>;
|
||||
|
||||
__host__ __device__ static constexpr auto GetUpperNumOfDimension(){return Number<nDimUp>{}};
|
||||
|
||||
__host__ __device__ static constexpr auto GetLowerNumOfDimension() { return Number<nDimLow>{}; }
|
||||
|
||||
__host__ __device__ static constexpr auto GetLowerLengths() { return LowLengths{}; }
|
||||
|
||||
__host__ __device__ static constexpr auto GetUpperLengths()
|
||||
{
|
||||
return Sequence<accumulate_on_sequence(
|
||||
GetLowerLengths(), math::multiplies<index_t>{}, Number<1>{})>{};
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetLowerId(UpperId id_up)
|
||||
{
|
||||
LowerId id_low;
|
||||
|
||||
// not implemeneted
|
||||
|
||||
return id_low;
|
||||
}
|
||||
|
||||
// id_low_diff depends on id_low_old, so id_low need to be up-to-date
|
||||
__host__ __device__ static constexpr auto GetLowerIdDiff(UpperId id_up_diff, LowerId id_low_old)
|
||||
{
|
||||
LowerId id_low_diff;
|
||||
|
||||
// not implemeneted
|
||||
|
||||
return id_low_diff;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool IsLinearTransform() { return false; }
|
||||
};
|
||||
|
||||
// UpLengths: Sequence<...>
|
||||
template <index_t LowLength, class UpLengths>
|
||||
struct Unmerge
|
||||
{
|
||||
static constexpr index_t nDimLow = 1;
|
||||
static constexpr index_t nDimUp = UpLengths::GetSize();
|
||||
|
||||
__host__ __device__ constexpr Unmerge()
|
||||
{
|
||||
static_assert(LowLength == accumulate_on_sequence(
|
||||
UpLengths{}, math::multiplies<index_t>{}, Number<1>{}),
|
||||
"wrong! UpLengths need to be ");
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetUpperNumOfDimension(){return Number<nDimUp>{}};
|
||||
|
||||
__host__ __device__ static constexpr auto GetLowerNumOfDimension() { return Number<nDimLow>{}; }
|
||||
|
||||
__host__ __device__ static constexpr auto GetLowerLengths() { return Sequence<LowLength>{}; }
|
||||
|
||||
__host__ __device__ static constexpr auto GetUpperLengths() { return UpLengths{}; }
|
||||
|
||||
__host__ __device__ static constexpr auto GetLowerId(UpperId id_up)
|
||||
{
|
||||
constexpr auto scans = typename sequence_reverse_inclusive_scan<UpLengths,
|
||||
math::multiplies<index_t>,
|
||||
1>::type{};
|
||||
|
||||
LowerId id_low{0};
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto idim) { id_low[0] += id_up[idim] * scans[idim]; });
|
||||
|
||||
return id_low;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetLowerIdDiff(UpperId id_up_diff)
|
||||
{
|
||||
return GetLowerId(id_up_diff);
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
|
||||
};
|
||||
|
||||
// UpLengths: Sequence<...>
|
||||
// Coefficients: Sequence<...>
|
||||
// id_low = coefficients[0, ...nDimUp-1] * id_up[0, ...nDimUp-1] + coefficients[nDimUp]
|
||||
template <index_t LowLength, class UpLengths, class Coefficients>
|
||||
struct Embed
|
||||
{
|
||||
static constexpr index_t nDimLow = 1;
|
||||
static constexpr index_t nDimUp = UpLengths::GetSize();
|
||||
|
||||
static constexpr auto mCoefficients = Coefficients{};
|
||||
|
||||
__host__ __device__ constexpr Embed()
|
||||
{
|
||||
static_assert(UpLengths::GetSize() == nDimUp && Coefficients::GetSize() == nDimUp + 1,
|
||||
"wrong! # of dimensions not consistent");
|
||||
|
||||
constexpr index_t low_id_max =
|
||||
Coefficents.Back() + accumulate_on_sequence(UpLengths{} * Coefficients::PopBack(),
|
||||
math::plus<index_t>{},
|
||||
Number<0>{});
|
||||
|
||||
static_assert(low_id_max < LowLength, "wrong! lower-id will go out of range");
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetUpperNumOfDimension(){return Number<nDimUp>{}};
|
||||
|
||||
__host__ __device__ static constexpr auto GetLowerNumOfDimension() { return Number<nDimLow>{}; }
|
||||
|
||||
__host__ __device__ static constexpr auto GetLowerLengths() { return Sequence<LowLength>{}; }
|
||||
|
||||
__host__ __device__ static constexpr auto GetUpperLengths() { return UpLengths{}; }
|
||||
|
||||
__host__ __device__ static constexpr auto GetLowerId(UpperId id_up)
|
||||
{
|
||||
LowerId id_low{mCoefficients[nDimUp]};
|
||||
|
||||
static_for<0, nDimUp, 1>{}(
|
||||
[&](auto idim) { id_low[0] += id_up[idim] * mCoefficients[idim]; });
|
||||
|
||||
return id_low;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetLowerIdDiff(UpperId id_up_diff)
|
||||
{
|
||||
LowerId id_low_diff{0};
|
||||
|
||||
static_for<0, nDimUp, 1>{}(
|
||||
[&](auto idim) { id_low_diff[0] += id_up_diff[idim] * mCoefficients[idim]; });
|
||||
|
||||
return id_low_diff;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -3,75 +3,159 @@
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "dimension.hpp"
|
||||
#include "multi_index_transform.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <class Lengths, class Strides>
|
||||
template <class... NativeDimensions>
|
||||
struct NativeTensorDescriptor
|
||||
{
|
||||
using type = NativeTensorDescriptor;
|
||||
static constexpr index_t nDim = Lengths::GetSize();
|
||||
using type = NativeTensorDescriptor;
|
||||
static constexpr auto mDimensions = Tuple<NativeDimensions...>;
|
||||
static constexpr index_t nDim = mDimensions::GetSize();
|
||||
|
||||
using Id = MultiIndex<nDim>;
|
||||
using Index = MultiIndex<nDim>;
|
||||
|
||||
__host__ __device__ static constexpr auto GetNumOfDimension() { return Number<nDim>{}; }
|
||||
|
||||
__host__ __device__ static constexpr auto GetLengths() { return Lengths{}; }
|
||||
|
||||
__host__ __device__ static constexpr auto GetStrides() { return Strides{}; }
|
||||
|
||||
__host__ __device__ static constexpr auto GetLength(index_t IDim) { return Lengths{}[IDim]; }
|
||||
|
||||
__host__ __device__ static constexpr auto GetStride(index_t IDim) { return Strides{}[IDim]; }
|
||||
|
||||
__host__ __device__ static constexpr index_t GetOffset(Id id)
|
||||
__host__ __device__ static constexpr auto GetLengths()
|
||||
{
|
||||
// not implemented
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetStrides()
|
||||
{
|
||||
// not implemented
|
||||
}
|
||||
|
||||
template <index_t IDim>
|
||||
__host__ __device__ static constexpr auto GetLength(Number<IDim>)
|
||||
{
|
||||
return mDimensions.Get(Number<IDim>{}).GetLength();
|
||||
}
|
||||
|
||||
template <index_t IDim>
|
||||
__host__ __device__ static constexpr auto GetStride(Number<IDim>)
|
||||
{
|
||||
return mDimensions.Get(Number<IDim>{}).GetStride();
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr index_t GetOffset(Index idx)
|
||||
{
|
||||
index_t offset = 0;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto idim) { offset += idx[idim] * GetStride(idim); });
|
||||
|
||||
return offset;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr index_t GetOffsetDiff(Index idx_diff)
|
||||
{
|
||||
index_t offset_diff = 0;
|
||||
|
||||
static_for<0, nDim, 1>{}(
|
||||
[&](auto idim) { offset_diff += idx_diff[idim] * GetStride(idim); });
|
||||
|
||||
return offset_diff;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto AreUpperIndex2OffsetTransformLinear();
|
||||
{
|
||||
// TODO: re-implement "Sequence", so that it can take other data-type (including bool) as
|
||||
// element
|
||||
return uniform_sequence_gen<nDim, 1>{};
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetIndependentDimensionGroups()
|
||||
{
|
||||
// not implemented, should return Tuple<Sequence<0>, Sequence<1>, ...>
|
||||
return xxx;
|
||||
}
|
||||
};
|
||||
|
||||
// LowerTensorDescriptor
|
||||
// Transforms: std::tuple<DimensionTransforms...>
|
||||
// LowerIds: std::tuple<Sequence<...>>
|
||||
// UpperIds: std::tuple<Sequence<...>>
|
||||
template <class LowTensorDescriptor,
|
||||
class Transforms,
|
||||
class LowDimensionMasks,
|
||||
class UpDimensionMasks>
|
||||
// LowerDimensionIds: std::tuple<Sequence<...>>
|
||||
// UpperDimensionIds: std::tuple<Sequence<...>>
|
||||
template <class LowTensorDescriptor, class Transforms, class LowDimensionIds, class UpDimensionIds>
|
||||
struct TransformedTensorDescriptor
|
||||
{
|
||||
using type = TransformedTensorDescriptor;
|
||||
static constexpr index_t nDimUp = xxxx;
|
||||
static constexpr index_t nDimLow = xxx;
|
||||
static constexpr index_t nDimUp = GetUpperNumOfDimension();
|
||||
static constexpr index_t nDimLow = GetLowerNumOfDimension();
|
||||
|
||||
static constexpr index_t nTransform = Transforms::GetSize();
|
||||
|
||||
using UpperId = MultiIndex<nDimUp>;
|
||||
using LowerId = MultiIndex<nDimLow>;
|
||||
using UpperIndex = MultiIndex<nDimUp>;
|
||||
using LowerIndex = MultiIndex<nDimLow>;
|
||||
|
||||
__host__ __device__ static constexpr TransformedTensorDescriptor()
|
||||
{
|
||||
static_assert(nTransform == Transforms::GetSize() &&
|
||||
nTransform == LowDimensionMasks::GetSize() &&
|
||||
nTransform == UpDimensionMasks::GetSize(),
|
||||
nTransform == LowDimensionIds::GetSize() &&
|
||||
nTransform == UpDimensionIds::GetSize(),
|
||||
"wrong! # of transformations not the same");
|
||||
|
||||
// TODO: sanity check: LowDimensionMasks should include all low-dimensions,
|
||||
// UpDimensionMasks should include all up-dimensions
|
||||
// TODO: sanity check: LowDimensionIds should include all low-dimensions,
|
||||
// UpDimensionIds should include all up-dimensions
|
||||
|
||||
// TODO: sanity check: while an up-dimension could be associated with multiple
|
||||
// transformation,
|
||||
// a low-dimension should be associated with only one transformation
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetNumOfLowerDimension()
|
||||
{
|
||||
// Here, we assume all lower-dimensions are active
|
||||
// TODO: sanity-check all lower-dimension are indeed active
|
||||
constexpr auto low_active_dims = unique_sort_sequence(
|
||||
merge_tuple_of_sequences(LowDimensionIds{}), math::less<index_t>{});
|
||||
|
||||
return low_active_dims.GetSize();
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetNumOfUpperDimension()
|
||||
{
|
||||
constexpr auto up_active_dims =
|
||||
unique_sort_sequence(merge_tuple_of_sequences(UpDimensionIds{}), math::less<index_t>{});
|
||||
return up_active_dims.GetSize();
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetNumOfDimension()
|
||||
{
|
||||
// not implemented
|
||||
return GetNumOfUpperDimension();
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetLengths()
|
||||
{
|
||||
// not implemented
|
||||
struct lambda_get_upper_lengths
|
||||
{
|
||||
template <class Transform>
|
||||
__host__ __device__ constexpr auto operator()(Transform tran) const
|
||||
{
|
||||
return tran.GetUpperLengths();
|
||||
}
|
||||
};
|
||||
|
||||
constexpr auto tuple_of_upper_lengths =
|
||||
transform_tuple(Transforms, lambda_get_upper_lengths{});
|
||||
|
||||
constexpr auto all_upper_lengths = merge_tuple_of_sequences(tuple_of_upper_lengths);
|
||||
|
||||
constexpr auto all_upper_dimension_ids = merge_tuple_of_sequences(UpDimensionIds{});
|
||||
|
||||
// TODO: sanity-check all_upper_dimension_ids contain all upper-dimensions
|
||||
// TODO: sanity-check all_upper_lengths have no conflicting upper-length
|
||||
|
||||
using sort_dimension_ids =
|
||||
sequence_unique_sort<decltype(all_upper_dimension_ids), math::less<index_t>>;
|
||||
constexpr auto sorted_upper_dimension_ids = typename sort_dimension_ids::type;
|
||||
constexpr auto sorted2unsorted_map = typename sort_dimension_ids::sorted2unsorted_map_type;
|
||||
|
||||
constexpr auto sorted_upper_lengths =
|
||||
sequence_element_pick(all_upper_lengths, sorted2unsorted_map);
|
||||
|
||||
return sorted_upper_lengths;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetLowerTensorDescriptor()
|
||||
@@ -79,17 +163,57 @@ struct TransformedTensorDescriptor
|
||||
return LowTensorDescriptor{};
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr index_t GetLowerId(UpperId id_up)
|
||||
__host__ __device__ static constexpr index_t GetLowerIndex(UpperIndex idx_up)
|
||||
{
|
||||
// not implemented
|
||||
LowerIndex idx_low;
|
||||
|
||||
static_for<0, nTransform, 1>{}([&](auto itran) {
|
||||
constexpr auto tran = Transforms::Get(itran);
|
||||
|
||||
constexpr auto idx_low_part = pick_array_element(idx_low, LowDimensionIds::Get(itran));
|
||||
constexpr auto idx_up_part = pick_array_element(idx_up, UpDimensionIds::Get(itran));
|
||||
|
||||
// this assumes each lower (single) index is only associated with one transformation,
|
||||
// which is required for index transformation, and has been checked during constructor
|
||||
// of TransformedTensorDescriptor
|
||||
idx_low_part = tran.GetLowerIndex(idx_up_part);
|
||||
});
|
||||
|
||||
return idx_low;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr index_t GetOffset(UpperId id_up)
|
||||
__host__ __device__ static constexpr index_t GetLowerIndexDiff(UpperIndex idx_up_diff,
|
||||
LowerIndex idx_low_old)
|
||||
{
|
||||
return GetLowerTensorDescriptor().GetOffset(GetLowerId(id_up));
|
||||
LowerIndex idx_low_diff;
|
||||
|
||||
static_for<0, nTransform, 1>{}([&](auto itran) {
|
||||
constexpr auto tran = Transforms::Get(itran);
|
||||
|
||||
constexpr auto idx_up_diff_part =
|
||||
pick_array_element(idx_up_diff, UpDimensionIds::Get(itran));
|
||||
|
||||
constexpr auto idx_low_diff_part =
|
||||
pick_array_element(idx_low_diff, LowDimensionIds::Get(itran));
|
||||
|
||||
constexpr auto idx_low_old_part =
|
||||
pick_array_element(idx_low_old, LowDimensionIds::Get(itran));
|
||||
|
||||
// this assume each lower (single) index is associated with only one transformation,
|
||||
// which is required for index transformation, and has been checked during constructor
|
||||
// of TransformedTensorDescriptor
|
||||
idx_low_diff_part = tran.GetLowerIndex(idx_up_diff_part, idx_low_old_part);
|
||||
});
|
||||
|
||||
return idx_low_diff;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto AreUpperId2OffsetLinear();
|
||||
__host__ __device__ static constexpr index_t GetOffset(UpperIndex idx_up)
|
||||
{
|
||||
return GetLowerTensorDescriptor().GetOffset(GetLowerIndex(idx_up));
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto AreUpperIndex2OffsetTransformLinear();
|
||||
{
|
||||
// not implemented
|
||||
}
|
||||
|
||||
@@ -79,6 +79,56 @@ struct Array
|
||||
}
|
||||
};
|
||||
|
||||
// A: Array
|
||||
// Picks: Sequence<...>
|
||||
template <class Arr, class Picks>
|
||||
ArrayElementPicker
|
||||
{
|
||||
__host__ __device__ constexpr ArrayElementPicker(Arr & array) : mData{array}
|
||||
{
|
||||
constexpr index_t imax =
|
||||
accumulate_on_sequence(Picks{}, math::maxer<index_t>{}, Number<0>{});
|
||||
|
||||
static_assert(imax < Picks::GetSize(), "wrong! exceeding max id");
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr index_t GetSize() { return Picks::GetSize(); }
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr TData operator[](Number<I>) const
|
||||
{
|
||||
constexpr auto IP = Picks::Get(Number<I>{});
|
||||
return mData[IP];
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr TData operator[](index_t i) const
|
||||
{
|
||||
constexpr index_t ip = Picks{}[i];
|
||||
return mData[ip];
|
||||
}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ TData& operator()(Number<I>)
|
||||
{
|
||||
constexpr auto IP = Picks::Get(Number<I>{});
|
||||
return mData[IP];
|
||||
}
|
||||
|
||||
__host__ __device__ TData& operator()(index_t i)
|
||||
{
|
||||
constexpr index_t ip = Picks{}[i];
|
||||
return mData[ip];
|
||||
}
|
||||
|
||||
Arr& mData;
|
||||
};
|
||||
|
||||
template <class Arr, class Picks>
|
||||
__host__ __device__ constexpr auto pick_array_element(Arr& a, Picks)
|
||||
{
|
||||
return ArrayElementPicker<Arr, Picks>(a);
|
||||
}
|
||||
|
||||
template <index_t... Is>
|
||||
__host__ __device__ constexpr auto sequence2array(Sequence<Is...>)
|
||||
{
|
||||
|
||||
@@ -6,6 +6,9 @@
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <index_t, index_t, index_t>
|
||||
struct static_for;
|
||||
|
||||
template <index_t...>
|
||||
struct Sequence;
|
||||
|
||||
@@ -294,6 +297,18 @@ struct sequence_reverse<Sequence<I0, I1>>
|
||||
using type = Sequence<I1, I0>;
|
||||
};
|
||||
|
||||
// Placeholder, not implemented: the struct has no members yet.
// NOTE(review): judging by the name this is meant to sort Seq at compile time
// using Compare -- confirm once implemented.
template <typename Seq, typename Compare>
struct sequence_sort
{
};
|
||||
|
||||
// Placeholder, not implemented: the struct has no members yet.
// NOTE(review): judging by the name this is meant to sort Seq at compile time
// using Compare and drop duplicates -- confirm once implemented.
template <typename Seq, typename Compare>
struct sequence_unique_sort
{
};
|
||||
|
||||
template <class Seq>
|
||||
struct is_valid_sequence_map
|
||||
{
|
||||
@@ -486,6 +501,35 @@ __host__ __device__ constexpr auto inclusive_scan_sequence(Seq, Reduce, Number<I
|
||||
return reverse_inclusive_scan_sequence(Seq{}.Reverse(), Reduce{}, Number<Init>{}).Reverse();
|
||||
}
|
||||
|
||||
template <class Seq, class Reduce>
|
||||
struct lambda_accumulate_on_sequence
|
||||
{
|
||||
const Reduce& f;
|
||||
index_t& result;
|
||||
|
||||
__host__ __device__ constexpr lambda_accumulate_on_sequence(const Reduce& f_, index_t& result_)
|
||||
: f(f_), result(result_)
|
||||
{
|
||||
}
|
||||
|
||||
template <class IDim>
|
||||
__host__ __device__ constexpr index_t operator()(IDim) const
|
||||
{
|
||||
return result = f(result, Seq::Get(IDim{}));
|
||||
}
|
||||
};
|
||||
|
||||
template <class Seq, class Reduce, index_t Init>
|
||||
__host__ __device__ constexpr index_t
|
||||
accumulate_on_sequence(Seq, Reduce f, Number<Init> /*initial_value*/)
|
||||
{
|
||||
index_t result = Init;
|
||||
|
||||
static_for<0, Seq::mSize, 1>{}(lambda_accumulate_on_sequence<Seq, Reduce>(f, result));
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
template <index_t... Xs>
|
||||
__host__ __device__ void print_Sequence(const char* s, Sequence<Xs...>)
|
||||
{
|
||||
|
||||
@@ -37,34 +37,5 @@ struct static_for
|
||||
}
|
||||
};
|
||||
|
||||
template <class Seq, class Reduce>
|
||||
struct lambda_accumulate_on_sequence
|
||||
{
|
||||
const Reduce& f;
|
||||
index_t& result;
|
||||
|
||||
__host__ __device__ constexpr lambda_accumulate_on_sequence(const Reduce& f_, index_t& result_)
|
||||
: f(f_), result(result_)
|
||||
{
|
||||
}
|
||||
|
||||
template <class IDim>
|
||||
__host__ __device__ constexpr index_t operator()(IDim) const
|
||||
{
|
||||
return result = f(result, Seq::Get(IDim{}));
|
||||
}
|
||||
};
|
||||
|
||||
template <class Seq, class Reduce, index_t Init>
|
||||
__host__ __device__ constexpr index_t
|
||||
accumulate_on_sequence(Seq, Reduce f, Number<Init> /*initial_value*/)
|
||||
{
|
||||
index_t result = Init;
|
||||
|
||||
static_for<0, Seq::mSize, 1>{}(lambda_accumulate_on_sequence<Seq, Reduce>(f, result));
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
@@ -31,6 +31,12 @@ struct multiplies
|
||||
__host__ __device__ constexpr T operator()(T a, T b) const { return a * b; }
|
||||
};
|
||||
|
||||
template <class T>
|
||||
struct maxer
|
||||
{
|
||||
__host__ __device__ constexpr T operator()(T a, T b) const { return a >= b ? a : b; }
|
||||
};
|
||||
|
||||
template <class T>
|
||||
struct integer_divide_ceiler
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user