mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 08:50:17 +00:00
adding implicit gemm v3
This commit is contained in:
@@ -30,11 +30,6 @@ struct ConstantMergedTensorDescriptor
|
||||
});
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr index_t GetNumOfOriginalDimension()
|
||||
{
|
||||
return TensorDesc::GetNumOfDimension();
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr index_t GetNumOfDimension()
|
||||
{
|
||||
constexpr auto merged_dim_ranges = std::make_tuple(MergedDimRanges...);
|
||||
@@ -51,11 +46,16 @@ struct ConstantMergedTensorDescriptor
|
||||
};
|
||||
|
||||
constexpr index_t num_lost_dim = static_const_reduce_n<sizeof...(MergedDimRanges)>{}(
|
||||
f_calculate_num_of_lost_dim, mod_conv::plus<index_t>{});
|
||||
f_calculate_num_of_lost_dim, std::plus<index_t>{});
|
||||
|
||||
return TensorDesc::GetNumOfDimension() - num_lost_dim;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr index_t GetNumOfOriginalDimension()
|
||||
{
|
||||
return TensorDesc::GetNumOfDimension();
|
||||
}
|
||||
|
||||
template <index_t IDim>
|
||||
__host__ __device__ static constexpr bool IsMergedDimension(Number<IDim>)
|
||||
{
|
||||
@@ -71,7 +71,7 @@ struct ConstantMergedTensorDescriptor
|
||||
template <index_t IDim>
|
||||
__host__ __device__ static constexpr bool GetStride(Number<IDim>)
|
||||
{
|
||||
static_assert(!IsMergedDimension(Number<IDim>{}, "wrong! A merged dimension does not have uniform stride")
|
||||
static_assert(!IsMergedDimension(Number<IDim>{}, "wrong! stride of a merged dimension is undefined")
|
||||
// not implemented
|
||||
}
|
||||
|
||||
|
||||
@@ -85,24 +85,35 @@ struct ConstantTensorDescriptor
|
||||
|
||||
__host__ __device__ static constexpr index_t GetElementSize()
|
||||
{
|
||||
return accumulate_on_sequence(Lengths{}, mod_conv::multiplies<index_t>{}, Number<1>{});
|
||||
return accumulate_on_sequence(Lengths{}, std::multiplies<index_t>{}, Number<1>{});
|
||||
}
|
||||
|
||||
#if 0
|
||||
// c++14 doesn't support constexpr lambdas, has to use this trick instead
|
||||
struct GetElementSpace_f
|
||||
struct f_GetElementSpace_impl
|
||||
{
|
||||
template <class IDim>
|
||||
__host__ __device__ constexpr index_t operator()(IDim idim) const
|
||||
{
|
||||
return (Type{}.GetLength(idim) - 1) * Type{}.GetStride(idim);
|
||||
}
|
||||
__host__ __device__ constexpr index_t operator()(index_t length, index_t stride) const
|
||||
{
|
||||
return (length - 1) * stride;
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
template <class Align = Number<1>>
|
||||
__host__ __device__ static constexpr index_t GetElementSpace(Align align = Align{})
|
||||
{
|
||||
#if 0
|
||||
index_t element_space_unaligned =
|
||||
static_const_reduce_n<nDim>{}(GetElementSpace_f{}, mod_conv::plus<index_t>{}) + 1;
|
||||
static_const_reduce_n<nDim>{}(f_GetElementSpace_impl{}, std::plus<index_t>{}) + 1;
|
||||
#else
|
||||
constexpr index_t element_space_unaligned = accumulate_on_sequence(
|
||||
(GetLengths() - Number<1>{}) * GetStrides(), std::plus<index_t>{}, Number<1>{});
|
||||
#endif
|
||||
|
||||
return align.Get() * ((element_space_unaligned + align.Get() - 1) / align.Get());
|
||||
}
|
||||
@@ -140,9 +151,9 @@ struct ConstantTensorDescriptor
|
||||
constexpr auto multi_id = Sequence<Is...>{};
|
||||
|
||||
constexpr auto seq_tmp =
|
||||
transform_sequences(mod_conv::multiplies<index_t>{}, multi_id, GetStrides());
|
||||
transform_sequences(std::multiplies<index_t>{}, multi_id, GetStrides());
|
||||
|
||||
return accumulate_on_sequence(seq_tmp, mod_conv::plus<index_t>{}, Number<0>{});
|
||||
return accumulate_on_sequence(seq_tmp, std::plus<index_t>{}, Number<0>{});
|
||||
}
|
||||
|
||||
__host__ __device__ static Array<index_t, nDim> GetMultiIndex(index_t id)
|
||||
@@ -167,34 +178,112 @@ struct ConstantTensorDescriptor
|
||||
}
|
||||
|
||||
template <index_t IDims...>
|
||||
__host__ __device__ static constexpr auto Extract(Number<IDims>... /*extracted_dims...*/)
|
||||
__host__ __device__ static constexpr auto Extract(Number<IDims>... extract_dims)
|
||||
{
|
||||
static_assert(sizeof...(IDims) <= GetNumOfDimension(), "wrong!");
|
||||
static_assert(sizeof...(IDims) <= GetNumOfDimension(),
|
||||
"wrong! too many number of dimensions to be extracted");
|
||||
|
||||
constexpr auto extracted_lengths = Sequence<Lengths{}.Get(Number<IDims>{})...>{};
|
||||
constexpr auto extracted_strides = Sequence<Strides{}.Get(Number<IDims>{})...>{};
|
||||
|
||||
return make_ConstantTensorDescriptor(extracted_lenghts, extracted_strides);
|
||||
return make_ConstantTensorDescriptor(Lengths{}.Extract(extract_dims),
|
||||
Strides{}.Extract(extract_dims));
|
||||
}
|
||||
|
||||
template <index_t IDim, index_t SliceLen>
|
||||
__host__ __device__ static constexpr auto Slice(Number<IDim>, Number<SliceLen>)
|
||||
{
|
||||
// not implemented
|
||||
return make_ConstantTensorDescriptor(Lengths{}.Modify(Number<IDim>{}, Number<SliceLen>{}),
|
||||
Strides{});
|
||||
}
|
||||
|
||||
template <index_t IDim, index_t... FoldLengths>
|
||||
__host__ device__ static constexpr auto Fold(Number<IDim>, Sequence<FoldLengths...>)
|
||||
template <index_t IDim, index_t... FoldIntervals>
|
||||
__host__ device__ static constexpr auto Fold(Number<IDim>, Number<FoldIntervals>...)
|
||||
{
|
||||
// not implemented
|
||||
// need to check the Length dimension to be folded is dividable by FoldLengths
|
||||
constexpr auto fold_intervals = Sequence<FoldIntervals...>{};
|
||||
|
||||
constexpr fold_intervals_product =
|
||||
accumulate_on_sequence(fold_intervals, std::multiplies<index_t>{}, Number<1>{});
|
||||
|
||||
constexpr auto unfold_length = GetLength(Number<IDim>{});
|
||||
constexpr auto unfold_stride = GetStride(Number<IDim>{});
|
||||
|
||||
// length of the dimension to be folded needs to be dividable by fold_interval_product,
|
||||
// otherwise, folding is invalid
|
||||
static_assert(unfold_length % fold_interval_product == 0,
|
||||
"wrong! length on the dimension to be folded cannot be evenly divided!");
|
||||
|
||||
// folded lengths
|
||||
constexpr auto fold_lengths =
|
||||
Sequence<unfold_length / fold_interval_product>{}.Append(fold_intervals);
|
||||
|
||||
// folded strides
|
||||
constexpr auto fold_strides = transform_sequences(mod_conv::scales<index_t, unfold_stride>{},
|
||||
reverse_scan_sequence(fold_intervals.PushBack(Number<1>{}), std::multiplies<index_t>{});
|
||||
|
||||
// left and right lengths
|
||||
constexpr auto lengths_pair = GetLengths().Split(Number<I>{});
|
||||
constexpr auto left_lengths = lengths_pair.first;
|
||||
constexpr auto right_lengths = lengths_pair.second.PopFront();
|
||||
|
||||
// left and right strides
|
||||
constexpr auto strides_pair = GetStrides().Split(Number<I>{});
|
||||
constexpr auto left_strides = strides_pair.first;
|
||||
constexpr auto right_strides = strides_pair.second.PopFront();
|
||||
|
||||
return make_ConstantTensorDescriptor(left_lengths.Append(fold_lengths).Append(right_lengths),
|
||||
left_strides.Append(fold_strides).Append(right_strides));
|
||||
}
|
||||
|
||||
template <index_t FirstUnfoldDim, index_t LastUnfoldDim>
|
||||
__host__ __device__ static constexpr auto Unfold(Number<FirstUnfoldDim>, Number<LastUnfoldDim>)
|
||||
{
|
||||
// not implemented
|
||||
// need to check the dimensions to be unfold are packed, otherwise, Unfold is not permitted
|
||||
static_assert(FirstUnfoldDim >= 0 && LastUnfoldDim < nDim &&
|
||||
FirstUnfoldDim <= LastUnfoldDim,
|
||||
"wrong! should have FirstUnfoldDim <= LastUnfoldDim!");
|
||||
|
||||
// dimensions to be unfold need to be in descending order (w.r.t. strides), and need to be
|
||||
// packed in memory, otherwise, unfolding is invalid
|
||||
static_for<FirstUnfoldDim, LastUnfoldDim, 1>{}([&](auto IDim) {
|
||||
static_assert(
|
||||
GetStride(IDim) >= GetStride(Number<IDim.Get() + 1>{}),
|
||||
"wrong! dimensions to be unfolded need to be in descending order w.r.t strides");
|
||||
|
||||
static_assert(GetStride(IDim + 1) * GetLength(IDim + 1) == GetStride(IDim),
|
||||
"wrong! dimensions to be unfolded need to be packed");
|
||||
});
|
||||
|
||||
// lengths
|
||||
constexpr auto lens_pair1 = Lengths{}.Split(Number<LastUnfoldDim + 1>{});
|
||||
|
||||
constexpr auto right_lengths = lens_pair1.second;
|
||||
|
||||
constexpr auto lens_pair2 = lens_pair1.first.Split(Number<FirstUnfoldDim>{});
|
||||
|
||||
constexpr auto left_lengths = lens_pair2.first;
|
||||
|
||||
constexpr auto fold_lengths = lens_pair2.second;
|
||||
|
||||
constexpr index_t unfold_length =
|
||||
accumulate_on_sequence(fold_lengths, std::multiplies<index_t>{}, Number<1>{});
|
||||
|
||||
constexpr auto new_strides =
|
||||
left_strides.PopBack(Number<unfold_strides>{}).Append(right_strides);
|
||||
|
||||
// strides
|
||||
constexpr auto strides_pair1 = Strides{}.Split(Number<LastUnfoldDim + 1>{});
|
||||
|
||||
constexpr auto right_strides = strides_pair1.second;
|
||||
|
||||
constexpr auto strides_pair2 = strides_pair1.first.Split(Number<FirstUnfoldDim>{});
|
||||
|
||||
constexpr auto left_strides = strides_pair2.first;
|
||||
|
||||
constexpr auto fold_strides = strides_pair2.second;
|
||||
|
||||
constexpr index_t unfold_stride = fold_strides.Back();
|
||||
|
||||
constexpr auto new_strides =
|
||||
left_strides.PushBack(Number<unfold_strides>{}).Append(right_strides);
|
||||
|
||||
return make_ConstantTensorDescriptor(new_lengths, new_strides);
|
||||
}
|
||||
|
||||
template <index_t... IRs>
|
||||
|
||||
@@ -2,6 +2,15 @@
|
||||
#include "constant_integral.hip.hpp"
|
||||
#include "functional.hip.hpp"
|
||||
|
||||
struct EmptySequence
|
||||
{
|
||||
template <class Seq>
|
||||
__host__ __device__ constexpr Seq Append(Seq) const
|
||||
{
|
||||
return {};
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t... Is>
|
||||
struct Sequence
|
||||
{
|
||||
@@ -39,6 +48,11 @@ struct Sequence
|
||||
assert(false);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr auto Reverse() const
|
||||
{
|
||||
// not implemented
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr index_t Front() const { return mData[0]; }
|
||||
|
||||
__host__ __device__ constexpr index_t Back() const { return mData[mSize - 1]; }
|
||||
@@ -59,25 +73,192 @@ struct Sequence
|
||||
|
||||
__host__ __device__ constexpr auto PopBack() const;
|
||||
|
||||
template <index_t I, index_t X>
|
||||
__host__ __device__ constexpr auto Insert(Number<I>, Number<X>) const
|
||||
template <index_t Xs...>
|
||||
__host__ __device__ constexpr auto Append(Sequence<Xs...>) const
|
||||
{
|
||||
index_t data[mSize + 1];
|
||||
return Sequence<Is..., Xs...>{};
|
||||
}
|
||||
|
||||
static_for<0, I, 1>{}([&](auto Iter) {
|
||||
constexpr index_t iter = Iter.Get();
|
||||
data[iter] = mData[iter];
|
||||
});
|
||||
__host__ __device__ constexpr auto Append(EmptySequence) const { return Type{}; }
|
||||
|
||||
data[I] = X;
|
||||
template <index_t... Ns>
|
||||
__host__ __device__ constexpr auto Extract(Number<Ns>...) const
|
||||
{
|
||||
return Sequence<Type{}.Get(Number<Ns>)...>{};
|
||||
}
|
||||
|
||||
static_for<I, nSize, 1>{}([&](auto Iter) {
|
||||
constexpr index_t iter = Iter.Get();
|
||||
data[iter + 1] = mData[iter];
|
||||
template <index_t N>
|
||||
struct split_impl
|
||||
{
|
||||
template <class FirstSeq, class SecondSeq>
|
||||
__host__ __device__ constexpr auto operator()(FirstSeq, SecondSeq) const
|
||||
{
|
||||
constexpr new_first = FirstSeq{}.PushBack(Number<Second{}.Front()>{});
|
||||
constexpr new_second = SecondSeq{}.PopFront();
|
||||
|
||||
static_if<(N > 0)>{}([&](auto fwd) {
|
||||
return split_impl<N - 1>{}(new_first, fwd(new_second));
|
||||
}).else_([&](auto fwd) { return std::make_pair(new_first, fwd(new_second)); });
|
||||
}
|
||||
};
|
||||
|
||||
// split one sequence to two sequnces: [0, I) and [I, nSize)
|
||||
// return type is std::pair
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr auto Split(Number<I>) const
|
||||
{
|
||||
static_assert(I <= nSize, "wrong! split position is too high!");
|
||||
|
||||
static_if<(I == 0)>{}(
|
||||
[&](auto fwd) { return std::make_pair(EmptySequence<>{}, fwd(Type{})); });
|
||||
|
||||
static_if<(I == nSize)>{}(
|
||||
[&](auto fwd) { return std::make_pair(Type<>{}, fwd(EmptySequence<>{})); });
|
||||
|
||||
static_if<(I > 0 && I < nSize)>{}([&](auto fforwader) {
|
||||
constexpr auto first = Sequence<Type{}.Front()> {}
|
||||
constexpr auto second = Type{}.PopFront();
|
||||
|
||||
return split_impl<I - 1>{}(first, fwd(second));
|
||||
});
|
||||
}
|
||||
|
||||
template <index_t I, index_t X>
|
||||
__host__ __device__ constexpr auto Modify(Number<I>, Number<X>) const
|
||||
{
|
||||
constexpr auto first_second = Split(Number<I>{});
|
||||
|
||||
constexpr auto left = first_second.first;
|
||||
constexpr auto right = first_second.second.PopFront();
|
||||
|
||||
return left.PushBack(Number<X>{}).Append(right);
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t IBegin, index_t IEnd, index_t Increment>
|
||||
__host__ __device__ auto make_increasing_sequence(Number<IBegin>, Number<IEnd>, Number<Increment>)
|
||||
{
|
||||
static_assert(IBegin < IEnd, (IEnd - IBegin) % Increment == 0, "wrong!");
|
||||
|
||||
// not implemented
|
||||
}
|
||||
|
||||
template <index_t N, index_t X>
|
||||
__host__ __device__ auto make_uniform_sequence(Number<N>, Number<X>);
|
||||
{
|
||||
// not implemented
|
||||
}
|
||||
|
||||
template <index_t... Xs, index_t... Ys>
|
||||
__host__ __device__ constexpr auto operator+(Sequence<Xs...>, Sequence<Ys...>) const
|
||||
{
|
||||
static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");
|
||||
|
||||
return Sequence<(Xs + Ys)...>{};
|
||||
}
|
||||
|
||||
template <index_t... Xs, index_t... Ys>
|
||||
__host__ __device__ constexpr auto operator-(Sequence<Xs...> seq_x, Sequence<Ys...> seq_y) const
|
||||
{
|
||||
static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");
|
||||
|
||||
static_for<0, xs.GetSize(), 1>{}([&](auto I) { static_assert(seq_x.Get(I) >= seq_y.Get(I)); });
|
||||
|
||||
return Sequence<(Xs - Ys)...>{};
|
||||
}
|
||||
|
||||
template <index_t... Xs, index_t... Ys>
|
||||
__host__ __device__ constexpr auto operator*(Sequence<Xs...>, Sequence<Ys...>)const
|
||||
{
|
||||
static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");
|
||||
|
||||
return Sequence<(Xs * Ys)...>{};
|
||||
}
|
||||
|
||||
template <index_t... Xs, index_t... Ys>
|
||||
__host__ __device__ constexpr auto operator/(Sequence<Xs...>, Sequence<Ys...>) const
|
||||
{
|
||||
static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");
|
||||
|
||||
return Sequence<(Xs / Ys)...>{};
|
||||
}
|
||||
|
||||
template <index_t... Xs, index_t... Ys>
|
||||
__host__ __device__ constexpr auto operator%(Sequence<Xs...>, Sequence<Ys...>) const
|
||||
{
|
||||
static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");
|
||||
|
||||
return Sequence<(Xs % Ys)...>{};
|
||||
}
|
||||
|
||||
template <index_t... Xs, index_t... Ys>
|
||||
__host__ __device__ constexpr auto operator%(Sequence<Xs...>, Sequence<Ys...>) const
|
||||
{
|
||||
static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");
|
||||
|
||||
return Sequence<(Xs % Ys)...>{};
|
||||
}
|
||||
|
||||
template <index_t... Xs, index_t Y>
|
||||
__host__ __device__ constexpr auto operator+(Sequence<Xs...>, Number<Y>) const
|
||||
{
|
||||
return seq_x + make_uniform_sequence(Number<sizeof...(Xs)>, Number<Y>{});
|
||||
}
|
||||
|
||||
template <index_t... Xs, index_t Y>
|
||||
__host__ __device__ constexpr auto operator-(Sequence<Xs...>, Number<Y>) const
|
||||
{
|
||||
return seq_x - make_uniform_sequence(Number<sizeof...(Xs)>, Number<Y>{});
|
||||
}
|
||||
|
||||
template <index_t... Xs, index_t Y>
|
||||
__host__ __device__ constexpr auto operator*(Sequence<Xs...>, Number<Y>)const
|
||||
{
|
||||
return seq_x * make_uniform_sequence(Number<sizeof...(Xs)>, Number<Y>{});
|
||||
}
|
||||
|
||||
template <index_t... Xs, index_t Y>
|
||||
__host__ __device__ constexpr auto operator/(Sequence<Xs...>, Number<Y>) const
|
||||
{
|
||||
return seq_x / make_uniform_sequence(Number<sizeof...(Xs)>, Number<Y>{});
|
||||
}
|
||||
|
||||
template <index_t... Xs, index_t Y>
|
||||
__host__ __device__ constexpr auto operator%(Sequence<Xs...> seq_x, Number<Y> y) const
|
||||
{
|
||||
return seq_x % make_uniform_sequence(Number<sizeof...(Xs)>, Number<Y>{});
|
||||
}
|
||||
|
||||
template <index_t X, index_t... Ys>
|
||||
__host__ __device__ constexpr auto operator+(Number<X>, Sequence<Ys...>) const
|
||||
{
|
||||
return make_uniform_sequence(Number<sizeof...(Ys)>{}, Number<X>{}) + Sequence<Ys...>{};
|
||||
}
|
||||
|
||||
template <index_t X, index_t... Ys>
|
||||
__host__ __device__ constexpr auto operator-(Number<X>, Sequence<Ys...>) const
|
||||
{
|
||||
return make_uniform_sequence(Number<sizeof...(Ys)>{}, Number<X>{}) - Sequence<Ys...>{};
|
||||
}
|
||||
|
||||
template <index_t X, index_t... Ys>
|
||||
__host__ __device__ constexpr auto operator*(Number<X>, Sequence<Ys...>)const
|
||||
{
|
||||
return make_uniform_sequence(Number<sizeof...(Ys)>{}, Number<X>{}) * Sequence<Ys...>{};
|
||||
}
|
||||
|
||||
template <index_t X, index_t... Ys>
|
||||
__host__ __device__ constexpr auto operator/(Number<X>, Sequence<Ys...>) const
|
||||
{
|
||||
return make_uniform_sequence(Number<sizeof...(Ys)>{}, Number<X>{}) / Sequence<Ys...>{};
|
||||
}
|
||||
|
||||
template <index_t X, index_t... Ys>
|
||||
__host__ __device__ constexpr auto operator%(Number<X>, Sequence<Ys...>) const
|
||||
{
|
||||
return make_uniform_sequence(Number<sizeof...(Ys)>{}, Number<X>{}) % Sequence<Ys...>{};
|
||||
}
|
||||
|
||||
template <index_t I, index_t... Is>
|
||||
__host__ __device__ constexpr auto sequence_pop_front(Sequence<I, Is...>)
|
||||
{
|
||||
@@ -177,6 +358,12 @@ __host__ __device__ constexpr auto
|
||||
|
||||
#if 1
|
||||
// TODO: fix these mess
|
||||
template <class F, index_t... Xs>
|
||||
__host__ __device__ constexpr auto transform_sequences(F f, Sequence<Xs...>)
|
||||
{
|
||||
return Sequence<f(Xs)...>{};
|
||||
}
|
||||
|
||||
template <class F, index_t... Xs, index_t... Ys>
|
||||
__host__ __device__ constexpr auto transform_sequences(F f, Sequence<Xs...>, Sequence<Ys...>)
|
||||
{
|
||||
@@ -248,7 +435,7 @@ __host__ __device__ constexpr auto Sequence<Is...>::PopBack() const
|
||||
}
|
||||
|
||||
template <class Seq>
|
||||
struct accumulate_on_sequence_f
|
||||
struct accumulate_on_sequence_impl
|
||||
{
|
||||
template <class IDim>
|
||||
__host__ __device__ constexpr index_t operator()(IDim) const
|
||||
@@ -262,6 +449,42 @@ __host__ __device__ constexpr index_t
|
||||
accumulate_on_sequence(Seq, Reduce, Number<I> /*initial_value*/)
|
||||
{
|
||||
constexpr index_t a =
|
||||
static_const_reduce_n<Seq::mSize>{}(accumulate_on_sequence_f<Seq>{}, Reduce{});
|
||||
static_const_reduce_n<Seq::mSize>{}(accumulate_on_sequence_impl<Seq>{}, Reduce{});
|
||||
return Reduce{}(a, I);
|
||||
}
|
||||
|
||||
template <index_t NRemain>
|
||||
struct scan_sequence_impl
|
||||
{
|
||||
template <class ScanedSeq, class RemainSeq, class Reduce>
|
||||
__host__ __device__ constexpr auto operator()(ScanedSeq, RemainSeq, Reduce) const
|
||||
{
|
||||
static_assert(RemainSeq{}.GetSize() == NRemain,
|
||||
"wrong! RemainSeq and NRemain not consistent!");
|
||||
|
||||
constexpr index_t a = Reduce{}(ScanedSeq{}.Back(), RemainSeq{}.Front());
|
||||
constexpr auto scaned_seq = ScanedSeq{}.PushBack(Number<a>{});
|
||||
|
||||
static_if<(NRemain > 1)>{}([&](auto fwd) {
|
||||
return scan_sequence_impl<NRemain - 1>{}(
|
||||
scaned_seq, RemainSeq{}.PopFront(), fwd(Reduce{}));
|
||||
}).else_([&](auto fwd) { return fwd(scaned_seq); });
|
||||
}
|
||||
};
|
||||
|
||||
template <class Seq, class Reduce>
|
||||
__host__ __device__ constexpr auto scan_sequence(Seq, Reduce)
|
||||
{
|
||||
constexpr auto scaned_seq = Sequence<Seq{}.front()>{};
|
||||
constexpr auto remain_seq = Seq{}.PopFront();
|
||||
|
||||
constexpr index_t remain_size = Seq::GetSize() - 1;
|
||||
|
||||
return scan_sequence_impl<remain_size>{}(scaned_seq, remain_seq, Reduce{});
|
||||
}
|
||||
|
||||
template <class Seq, class Reduce>
|
||||
__host__ __device__ constexpr auto reverse_scan_sequence(Seq, Reduce)
|
||||
{
|
||||
return scan_seqeunce(Seq{}.Reverse(), Reduce{}).Reverse();
|
||||
}
|
||||
|
||||
@@ -156,7 +156,7 @@ struct Blockwise3dTensorCopy3
|
||||
"wrrong! BlockSize is not big enough for ThreadPerDims!");
|
||||
|
||||
constexpr index_t num_active_thread =
|
||||
accumulate_on_sequence(ThreadPerDims{}, mod_conv::multiplies<index_t>{}, Number<1>{});
|
||||
accumulate_on_sequence(ThreadPerDims{}, std::multiplies<index_t>{}, Number<1>{});
|
||||
|
||||
if(BlockSize > num_active_thread)
|
||||
{
|
||||
|
||||
@@ -495,7 +495,7 @@ struct Blockwise4dTensorCopy3
|
||||
"wrrong! BlockSize is not big enough for ThreadPerDims!");
|
||||
|
||||
constexpr index_t num_active_thread =
|
||||
accumulate_on_sequence(ThreadPerDims{}, mod_conv::multiplies<index_t>{}, Number<1>{});
|
||||
accumulate_on_sequence(ThreadPerDims{}, std::multiplies<index_t>{}, Number<1>{});
|
||||
|
||||
if(BlockSize > num_active_thread)
|
||||
{
|
||||
|
||||
@@ -133,7 +133,7 @@ struct BlockwiseTensorSliceReorderCopy_v3
|
||||
constexpr auto thread_sub_tensor_lengths = SrcSubLengths{};
|
||||
|
||||
constexpr auto src_data_per_cluster_per_dims = transform_sequences(
|
||||
mod_conv::multiplies<index_t>{}, thread_sub_tensor_lengths, SrcClusterLengths{});
|
||||
std::multiplies<index_t>{}, thread_sub_tensor_lengths, SrcClusterLengths{});
|
||||
|
||||
constexpr auto repeat_lengths =
|
||||
transform_sequences(mod_conv::integer_divide_ceiler<index_t>{},
|
||||
@@ -141,7 +141,7 @@ struct BlockwiseTensorSliceReorderCopy_v3
|
||||
src_data_per_cluster_per_dims);
|
||||
|
||||
constexpr auto thread_tensor_lengths = transform_sequences(
|
||||
mod_conv::multiplies<index_t>{}, thread_sub_tensor_lengths, repeat_lengths);
|
||||
std::multiplies<index_t>{}, thread_sub_tensor_lengths, repeat_lengths);
|
||||
|
||||
constexpr auto thread_tensor_desc = make_ConstantTensorDescriptor(thread_tensor_lengths);
|
||||
|
||||
@@ -154,7 +154,7 @@ struct BlockwiseTensorSliceReorderCopy_v3
|
||||
constexpr auto thread_sub_tensor_lengths = SrcSubLengths{};
|
||||
|
||||
constexpr auto src_data_per_cluster_per_dims = transform_sequences(
|
||||
mod_conv::multiplies<index_t>{}, thread_sub_tensor_lengths, SrcClusterLengths{});
|
||||
std::multiplies<index_t>{}, thread_sub_tensor_lengths, SrcClusterLengths{});
|
||||
|
||||
constexpr auto repeat_lengths =
|
||||
transform_sequences(mod_conv::integer_divide_ceiler<index_t>{},
|
||||
@@ -162,7 +162,7 @@ struct BlockwiseTensorSliceReorderCopy_v3
|
||||
src_data_per_cluster_per_dims);
|
||||
|
||||
constexpr auto thread_tensor_lengths = transform_sequences(
|
||||
mod_conv::multiplies<index_t>{}, thread_sub_tensor_lengths, repeat_lengths);
|
||||
std::multiplies<index_t>{}, thread_sub_tensor_lengths, repeat_lengths);
|
||||
|
||||
constexpr auto thread_tensor_desc = make_ConstantTensorDescriptor(thread_tensor_lengths);
|
||||
|
||||
@@ -170,10 +170,10 @@ struct BlockwiseTensorSliceReorderCopy_v3
|
||||
constexpr auto repeat_multi_id = decltype(repeat_multi_id_){};
|
||||
|
||||
constexpr auto src_data_multi_id = transform_sequences(
|
||||
mod_conv::multiplies<index_t>{}, repeat_multi_id, src_data_per_cluster_per_dims);
|
||||
std::multiplies<index_t>{}, repeat_multi_id, src_data_per_cluster_per_dims);
|
||||
|
||||
constexpr auto clipboard_data_multi_id = transform_sequences(
|
||||
mod_conv::multiplies<index_t>{}, repeat_multi_id, thread_sub_tensor_lengths);
|
||||
std::multiplies<index_t>{}, repeat_multi_id, thread_sub_tensor_lengths);
|
||||
|
||||
constexpr index_t src_offset = SrcDesc{}.Get1dIndex(src_data_multi_id);
|
||||
constexpr index_t clipboard_offset =
|
||||
@@ -194,7 +194,7 @@ struct BlockwiseTensorSliceReorderCopy_v3
|
||||
constexpr auto thread_sub_tensor_lengths = SrcSubLengths{};
|
||||
|
||||
constexpr auto src_data_per_cluster_per_dims = transform_sequences(
|
||||
mod_conv::multiplies<index_t>{}, thread_sub_tensor_lengths, SrcClusterLengths{});
|
||||
std::multiplies<index_t>{}, thread_sub_tensor_lengths, SrcClusterLengths{});
|
||||
|
||||
constexpr auto repeat_lengths =
|
||||
transform_sequences(mod_conv::integer_divide_ceiler<index_t>{},
|
||||
@@ -202,7 +202,7 @@ struct BlockwiseTensorSliceReorderCopy_v3
|
||||
src_data_per_cluster_per_dims);
|
||||
|
||||
constexpr auto thread_tensor_lengths = transform_sequences(
|
||||
mod_conv::multiplies<index_t>{}, thread_sub_tensor_lengths, repeat_lengths);
|
||||
std::multiplies<index_t>{}, thread_sub_tensor_lengths, repeat_lengths);
|
||||
|
||||
constexpr auto thread_tensor_desc = make_ConstantTensorDescriptor(thread_tensor_lengths);
|
||||
|
||||
@@ -210,10 +210,10 @@ struct BlockwiseTensorSliceReorderCopy_v3
|
||||
constexpr auto repeat_multi_id = decltype(repeat_multi_id_){};
|
||||
|
||||
constexpr auto clipboard_data_multi_id = transform_sequences(
|
||||
mod_conv::multiplies<index_t>{}, repeat_multi_id, thread_sub_tensor_lengths);
|
||||
std::multiplies<index_t>{}, repeat_multi_id, thread_sub_tensor_lengths);
|
||||
|
||||
constexpr auto src_data_multi_id = transform_sequences(
|
||||
mod_conv::multiplies<index_t>{}, repeat_multi_id, src_data_per_cluster_per_dims);
|
||||
std::multiplies<index_t>{}, repeat_multi_id, src_data_per_cluster_per_dims);
|
||||
|
||||
// reorder src_data_multi_id to get dst_data_multi_id
|
||||
constexpr auto dst_data_multi_id = src_data_multi_id.ReorderGivenNew2Old(MapDst2Src{});
|
||||
|
||||
@@ -27,16 +27,10 @@ struct is_same<T, T>
|
||||
};
|
||||
|
||||
namespace mod_conv { // namespace mod_conv
|
||||
template <class T>
|
||||
struct multiplies
|
||||
template <class T, T s>
|
||||
struct scales
|
||||
{
|
||||
__host__ __device__ constexpr T operator()(T a, T b) const { return a * b; }
|
||||
};
|
||||
|
||||
template <class T>
|
||||
struct plus
|
||||
{
|
||||
__host__ __device__ constexpr T operator()(T a, T b) const { return a + b; }
|
||||
__host__ __device__ constexpr T operator()(T a) const { return s * a; }
|
||||
};
|
||||
|
||||
template <class T>
|
||||
|
||||
@@ -10,6 +10,14 @@ struct forwarder
|
||||
}
|
||||
};
|
||||
|
||||
#if 0
|
||||
template<class F>
|
||||
__host__ __device__ constexpr auto unpacker(F f)
|
||||
{
|
||||
return [=](auto xs_array){ f(xs...); };
|
||||
}
|
||||
#endif
|
||||
|
||||
// Emulate compile time if statement for C++14
|
||||
// Get the idea from
|
||||
// "https://baptiste-wicht.com/posts/2015/07/simulate-static_if-with-c11c14.html"
|
||||
@@ -87,7 +95,7 @@ struct static_for_impl<Iter, 0, Increment>
|
||||
}
|
||||
};
|
||||
|
||||
// F signature: F(Number<I>)
|
||||
// F signature: F(Number<Iter>)
|
||||
template <index_t NBegin, index_t NEnd, index_t Increment>
|
||||
struct static_for
|
||||
{
|
||||
@@ -97,9 +105,8 @@ struct static_for
|
||||
static_assert((NEnd - NBegin) % Increment == 0,
|
||||
"Wrong! should satisfy (NEnd - NBegin) % Increment == 0");
|
||||
|
||||
static_if < NBegin<End>{}([&](auto forwarder) {
|
||||
static_for_impl<NBegin, NEnd - NBegin, forwarder(Increment)>{}(f);
|
||||
});
|
||||
static_if<(NBegin < End)>{}(
|
||||
[&](auto fwd) { static_for_impl<NBegin, NEnd - NBegin, fwd(Increment)>{}(f); });
|
||||
}
|
||||
};
|
||||
|
||||
@@ -127,11 +134,3 @@ struct static_const_reduce_n<1>
|
||||
return f(Number<0>{});
|
||||
}
|
||||
};
|
||||
|
||||
#if 0
|
||||
template<class F>
|
||||
__host__ __device__ constexpr auto unpacker(F f)
|
||||
{
|
||||
return [=](auto xs_array){ f(xs...); };
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -91,7 +91,7 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw
|
||||
// input tensor
|
||||
// memory layout descriptor in device memory [N0, N1, N2, C, H, W]
|
||||
constexpr auto in_n0_n1_n2_c_h_w_global_mem_desc =
|
||||
in_n_c_h_w_global_desc.Fold(I0, Sequence<N1, N2>{});
|
||||
in_n_c_h_w_global_desc.Fold(I0, Number<N1>{}, Number<N2>{});
|
||||
|
||||
// merged tensor descriptor in device memory [N1, N2, C, B], src of blockwise copy
|
||||
constexpr auto in_n1_n2_c_b_global_merged_desc =
|
||||
@@ -132,7 +132,7 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw
|
||||
|
||||
// weight tensor
|
||||
// tensor descriptor in device memory, src of blockwise copy
|
||||
constexpr auto wei_c_k_global_desc = wei_c_y_x_k_global_desc.Extract(Sequence<0, 3>{});
|
||||
constexpr auto wei_c_k_global_desc = wei_c_y_x_k_global_desc.Extract(I0, I3);
|
||||
|
||||
// tensor descriptor in LDS, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
@@ -257,7 +257,8 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw
|
||||
|
||||
// output memory layout descriptor in device memory
|
||||
constexpr auto out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc =
|
||||
out_n_k_h_w_global.Fold(I1, Sequence<K1, K2>{}).Fold(I0, Sequence<N1, N2>{});
|
||||
out_n_k_h_w_global.Fold(I1, Number<K1>{}, Number<K2>{})
|
||||
.Fold(I0, Number<N1>{}, Number<N2>{});
|
||||
|
||||
// output merged tensor descriptor in device memory, dst of threadwise copy
|
||||
constexpr auto out_k0_k1_k2_n1_b_n2_global_merged_desc =
|
||||
|
||||
Reference in New Issue
Block a user