mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 01:10:17 +00:00
use more constexpr
This commit is contained in:
@@ -443,7 +443,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
constexpr index_t HPad = 0;
|
||||
constexpr index_t WPad = 0;
|
||||
#elif 0
|
||||
#elif 1
|
||||
// 3x3 filter, 28x28 image
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 256;
|
||||
@@ -455,7 +455,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
constexpr index_t HPad = 0;
|
||||
constexpr index_t WPad = 0;
|
||||
#elif 1
|
||||
#elif 0
|
||||
// 1x1 filter, 28x28 image
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 512;
|
||||
@@ -549,12 +549,24 @@ int main(int argc, char* argv[])
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 1;
|
||||
|
||||
constexpr index_t HPad = 0;
|
||||
constexpr index_t WPad = 0;
|
||||
#elif 0
|
||||
// 1x1 filter, 7x7 image
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 512;
|
||||
constexpr index_t HI = 7;
|
||||
constexpr index_t WI = 7;
|
||||
constexpr index_t K = 2048;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 1;
|
||||
|
||||
constexpr index_t HPad = 0;
|
||||
constexpr index_t WPad = 0;
|
||||
#elif 0
|
||||
// 1x1 filter, 73x73 image
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 64;
|
||||
constexpr index_t C = 512;
|
||||
constexpr index_t HI = 73;
|
||||
constexpr index_t WI = 73;
|
||||
constexpr index_t K = 128;
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#pragma once
|
||||
#include "Sequence.hip.hpp"
|
||||
#include "functional.hip.hpp"
|
||||
#include "functional2.hip.hpp"
|
||||
|
||||
template <class TData, index_t NSize>
|
||||
struct Array
|
||||
@@ -25,14 +25,17 @@ struct Array
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr TData Get(Number<I>) const
|
||||
{
|
||||
static_assert(I < NSize, "wrong!");
|
||||
|
||||
return mData[I];
|
||||
}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr bool Set(Number<I>, TData x)
|
||||
__host__ __device__ constexpr void Set(Number<I>, TData x)
|
||||
{
|
||||
static_assert(I < NSize, "wrong!");
|
||||
|
||||
mData[I] = x;
|
||||
return true; // for constexpr
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr auto PushBack(TData x) const
|
||||
@@ -59,6 +62,7 @@ __host__ __device__ constexpr auto sequence2array(Sequence<Is...>)
|
||||
template <class TData, index_t NSize>
|
||||
__host__ __device__ constexpr auto make_zero_array()
|
||||
{
|
||||
#if 0
|
||||
Array<TData, NSize> a;
|
||||
|
||||
static_for<0, NSize, 1>{}([&](auto I) {
|
||||
@@ -67,6 +71,11 @@ __host__ __device__ constexpr auto make_zero_array()
|
||||
});
|
||||
|
||||
return a;
|
||||
#else
|
||||
constexpr auto zero_sequence = typename uniform_sequence_gen<NSize, 0>::SeqType{};
|
||||
constexpr auto zero_array = sequence2array(zero_sequence);
|
||||
return zero_array;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <class TData, index_t NSize, index_t... IRs>
|
||||
@@ -85,6 +94,7 @@ __host__ __device__ constexpr auto reorder_array_given_new2old(const Array<TData
|
||||
return new_array;
|
||||
}
|
||||
|
||||
#if 0
|
||||
template <class TData, index_t NSize, index_t... IRs>
|
||||
__host__ __device__ constexpr auto reorder_array_given_old2new(const Array<TData, NSize>& old_array,
|
||||
Sequence<IRs...> old2new)
|
||||
@@ -100,6 +110,45 @@ __host__ __device__ constexpr auto reorder_array_given_old2new(const Array<TData
|
||||
|
||||
return new_array;
|
||||
}
|
||||
#else
|
||||
template <class TData, index_t NSize, class MapOld2New>
|
||||
struct reorder_array_given_old2new_impl
|
||||
{
|
||||
const Array<TData, NSize>& old_array_ref;
|
||||
Array<TData, NSize>& new_array_ref;
|
||||
|
||||
__host__
|
||||
__device__ constexpr reorder_array_given_old2new_impl(const Array<TData, NSize>& old_array,
|
||||
Array<TData, NSize>& new_array)
|
||||
: old_array_ref(old_array), new_array_ref(new_array)
|
||||
{
|
||||
}
|
||||
|
||||
template <index_t IOldDim>
|
||||
__host__ __device__ constexpr void operator()(Number<IOldDim>) const
|
||||
{
|
||||
TData old_data = old_array_ref.Get(Number<IOldDim>{});
|
||||
|
||||
constexpr index_t INewDim = MapOld2New::Get(Number<IOldDim>{});
|
||||
|
||||
new_array_ref.Set(Number<INewDim>{}, old_data);
|
||||
}
|
||||
};
|
||||
|
||||
template <class TData, index_t NSize, index_t... IRs>
|
||||
__host__ __device__ constexpr auto reorder_array_given_old2new(const Array<TData, NSize>& old_array,
|
||||
Sequence<IRs...> old2new)
|
||||
{
|
||||
Array<TData, NSize> new_array;
|
||||
|
||||
static_assert(NSize == sizeof...(IRs), "NSize not consistent");
|
||||
|
||||
static_for<0, NSize, 1>{}(
|
||||
reorder_array_given_old2new_impl<TData, NSize, Sequence<IRs...>>(old_array, new_array));
|
||||
|
||||
return new_array;
|
||||
}
|
||||
#endif
|
||||
|
||||
template <class TData, index_t NSize, class ExtractSeq>
|
||||
__host__ __device__ constexpr auto extract_array(const Array<TData, NSize>& old_array, ExtractSeq)
|
||||
|
||||
@@ -115,15 +115,13 @@ struct ConstantMergedTensorDescriptor
|
||||
}
|
||||
|
||||
template <index_t I>
|
||||
constexpr __host__ __device__ bool operator()(Number<I>) const
|
||||
__host__ __device__ constexpr void operator()(Number<I>) const
|
||||
{
|
||||
constexpr index_t idim_original = OriginalDimsPartial::Get(Number<I>{});
|
||||
|
||||
index_t itmp = original_multi_id_partial_ref.Get(Number<I>{});
|
||||
|
||||
original_multi_id_ref.Set(Number<idim_original>{}, itmp);
|
||||
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -139,7 +137,7 @@ struct ConstantMergedTensorDescriptor
|
||||
}
|
||||
|
||||
template <index_t IDim>
|
||||
constexpr __host__ __device__ bool operator()(Number<IDim>) const
|
||||
__host__ __device__ constexpr void operator()(Number<IDim>) const
|
||||
{
|
||||
constexpr auto original_dims_partial =
|
||||
std::get<IDim>(std::tuple<OriginalDimMergeSeqs...>{});
|
||||
@@ -152,11 +150,10 @@ struct ConstantMergedTensorDescriptor
|
||||
static_for<0, original_dims_partial.GetSize(), 1>{}(
|
||||
GetOriginalMultiIndexFromMultiIndex_impl1<decltype(original_dims_partial)>(
|
||||
original_multi_id_partial, original_multi_id_ref));
|
||||
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
// return type is Array<...>
|
||||
__host__ __device__ static constexpr auto
|
||||
GetOriginalMultiIndexFromMultiIndex(Array<index_t, nDim> multi_id)
|
||||
{
|
||||
@@ -179,16 +176,6 @@ struct ConstantMergedTensorDescriptor
|
||||
}
|
||||
#endif
|
||||
|
||||
#if 0
|
||||
// return type is Sequence<...>
|
||||
template <index_t... Is>
|
||||
__host__ __device__ static constexpr auto GetOriginalMultiIndexFromMultiIndex(Sequence<Is...>)
|
||||
{
|
||||
// not implemented
|
||||
return Sequence<>{};
|
||||
}
|
||||
#endif
|
||||
|
||||
__host__ __device__ static constexpr index_t
|
||||
GetOffsetFromMultiIndex(Array<index_t, nDim> multi_id)
|
||||
{
|
||||
|
||||
@@ -37,10 +37,11 @@ struct Sequence
|
||||
template <class MapOld2New>
|
||||
__host__ __device__ static constexpr auto ReorderGivenOld2New(MapOld2New /*old2new*/)
|
||||
{
|
||||
#if 0
|
||||
static_assert(is_same<sequence_sort<MapOld2New>::SortedSeqType,
|
||||
arithmetic_sequence_gen<0, mSize, 1>::SeqType>::value,
|
||||
"wrong! invalid old2new map");
|
||||
|
||||
#endif
|
||||
constexpr auto map_new2old = typename sequence_map_inverse<MapOld2New>::SeqMapType{};
|
||||
|
||||
return ReorderGivenNew2Old(map_new2old);
|
||||
@@ -99,6 +100,7 @@ struct Sequence
|
||||
__host__ __device__ static constexpr auto Modify(Number<I>, Number<X>);
|
||||
};
|
||||
|
||||
// merge sequence
|
||||
template <class, class>
|
||||
struct sequence_merge;
|
||||
|
||||
@@ -108,6 +110,7 @@ struct sequence_merge<Sequence<Xs...>, Sequence<Ys...>>
|
||||
using SeqType = Sequence<Xs..., Ys...>;
|
||||
};
|
||||
|
||||
// arithmetic sequence
|
||||
template <index_t IBegin, index_t NSize, index_t Increment>
|
||||
struct arithmetic_sequence_gen_impl
|
||||
{
|
||||
@@ -139,7 +142,31 @@ struct arithmetic_sequence_gen
|
||||
typename arithmetic_sequence_gen_impl<IBegin, IEnd - IBegin, Increment>::SeqType;
|
||||
};
|
||||
|
||||
// reverse scan with init
|
||||
// transform sequence
|
||||
template <class, class>
|
||||
struct sequence_transform;
|
||||
|
||||
template <class F, index_t... Is>
|
||||
struct sequence_transform<F, Sequence<Is...>>
|
||||
{
|
||||
using SeqType = Sequence<F{}(Is)...>;
|
||||
};
|
||||
|
||||
// uniform sequence
|
||||
template <index_t NSize, index_t I>
|
||||
struct uniform_sequence_gen
|
||||
{
|
||||
struct return_constant
|
||||
{
|
||||
__host__ __device__ constexpr index_t operator()(index_t) const { return I; }
|
||||
};
|
||||
|
||||
using SeqType = typename sequence_transform<
|
||||
return_constant,
|
||||
typename arithmetic_sequence_gen<0, NSize, 1>::SeqType>::SeqType;
|
||||
};
|
||||
|
||||
// reverse inclusive scan (with init) sequence
|
||||
template <class, class, index_t>
|
||||
struct sequence_reverse_inclusive_scan;
|
||||
|
||||
@@ -166,22 +193,7 @@ struct sequence_reverse_inclusive_scan<Sequence<>, Reduce, Init>
|
||||
using SeqType = Sequence<>;
|
||||
};
|
||||
|
||||
#if 0
|
||||
// reverse scan with token
|
||||
template <class, class, index_t>
|
||||
struct sequence_reverse_inclusive_token_scan;
|
||||
|
||||
template <index_t I, index_t... Is, class F, index_t Token>
|
||||
struct sequence_reverse_inclusive_token_scan<Sequence<I, Is...>, F, Token>
|
||||
{
|
||||
using old_scan = typename sequence_reverse_inclusive_token_scan<Sequence<Is...>, F, Token>::SeqType;
|
||||
|
||||
static constexpr index_t new_reduce = Reduce{}(I, old_scan{}.Front());
|
||||
|
||||
using SeqType = typename sequence_merge<Sequence<new_reduce>, old_scan>::SeqType;
|
||||
};
|
||||
#endif
|
||||
|
||||
// extract sequence
|
||||
template <class, class>
|
||||
struct sequence_extract;
|
||||
|
||||
@@ -191,6 +203,7 @@ struct sequence_extract<Seq, Sequence<Is...>>
|
||||
using SeqType = Sequence<Seq{}.Get(Number<Is>{})...>;
|
||||
};
|
||||
|
||||
// split sequence
|
||||
template <class Seq, index_t I>
|
||||
struct sequence_split
|
||||
{
|
||||
@@ -203,6 +216,7 @@ struct sequence_split
|
||||
using SeqType1 = typename sequence_extract<Seq, range1>::SeqType;
|
||||
};
|
||||
|
||||
// reverse sequence
|
||||
template <class Seq>
|
||||
struct sequence_reverse
|
||||
{
|
||||
@@ -308,8 +322,10 @@ __host__ __device__ constexpr auto operator-(Sequence<Xs...> seq_x, Sequence<Ys.
|
||||
{
|
||||
static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");
|
||||
|
||||
#if 0
|
||||
static_for<0, seq_x.GetSize(), 1>{}(
|
||||
[&](auto I) { static_assert(seq_x.Get(I) >= seq_y.Get(I), "wrong! going to undeflow"); });
|
||||
#endif
|
||||
|
||||
return Sequence<(Xs - Ys)...>{};
|
||||
}
|
||||
@@ -388,10 +404,12 @@ __host__ __device__ constexpr auto operator-(Number<Y>, Sequence<Xs...>)
|
||||
{
|
||||
constexpr auto seq_x = Sequence<Xs...>{};
|
||||
|
||||
#if 0
|
||||
static_for<0, sizeof...(Xs), 1>{}([&](auto Iter) {
|
||||
constexpr auto I = decltype(Iter){};
|
||||
static_assert(seq_x.Get(I) <= Y, "wrong! going to underflow");
|
||||
});
|
||||
#endif
|
||||
|
||||
return Sequence<(Y - Xs)...>{};
|
||||
}
|
||||
|
||||
@@ -256,6 +256,7 @@ struct BlockwiseGenericTensorSliceCopy_v1
|
||||
make_ConstantTensorDescriptor_packed(thread_sub_tensor_lengths * repeat_lengths);
|
||||
|
||||
static_ford<decltype(repeat_lengths)>{}([&](auto repeat_multi_id_) {
|
||||
#if 0
|
||||
constexpr auto repeat_multi_id = sequence2array(decltype(repeat_multi_id_){});
|
||||
|
||||
const auto clipboard_data_multi_id_begin =
|
||||
@@ -269,6 +270,18 @@ struct BlockwiseGenericTensorSliceCopy_v1
|
||||
|
||||
const index_t dst_offset = DstDesc{}.GetOffsetFromMultiIndex(
|
||||
dst_data_multi_id_begin); // cannot not constexpr, why?
|
||||
#else
|
||||
constexpr auto clipboard_data_multi_id_begin =
|
||||
repeat_multi_id_ * thread_sub_tensor_lengths;
|
||||
|
||||
constexpr auto dst_data_multi_id_begin = repeat_multi_id_ * data_per_cluster_per_dims;
|
||||
|
||||
constexpr index_t clipboard_offset =
|
||||
thread_tensor_desc.GetOffsetFromMultiIndex(clipboard_data_multi_id_begin);
|
||||
|
||||
constexpr index_t dst_offset =
|
||||
DstDesc{}.GetOffsetFromMultiIndex(dst_data_multi_id_begin);
|
||||
#endif
|
||||
|
||||
threadwise_generic_tensor_slice_copy_v1(thread_tensor_desc,
|
||||
p_clipboard + clipboard_offset,
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
#include "Array.hip.hpp"
|
||||
#include "functional.hip.hpp"
|
||||
#include "functional2.hip.hpp"
|
||||
#include "functional3.hip.hpp"
|
||||
|
||||
#if USE_AMD_INLINE_ASM
|
||||
#include "amd_inline_asm.hip.hpp"
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
#pragma once
|
||||
#include "integral_constant.hip.hpp"
|
||||
#include "Sequence.hip.hpp"
|
||||
|
||||
struct forwarder
|
||||
{
|
||||
@@ -10,6 +11,14 @@ struct forwarder
|
||||
}
|
||||
};
|
||||
|
||||
struct swallow
|
||||
{
|
||||
template <class... Ts>
|
||||
__host__ __device__ constexpr swallow(Ts&&... ts)
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
#if 0
|
||||
template<class F>
|
||||
__host__ __device__ constexpr auto unpacker(F f)
|
||||
@@ -72,51 +81,6 @@ struct static_if<false>
|
||||
return Type{};
|
||||
}
|
||||
};
|
||||
template <index_t Iter, index_t Remaining, index_t Increment>
|
||||
struct static_for_impl
|
||||
{
|
||||
template <class F>
|
||||
constexpr __host__ __device__ void operator()(F f) const
|
||||
{
|
||||
static_assert(Remaining % Increment == 0, "wrong! Remaining % Increment != 0");
|
||||
static_assert(Increment <= Remaining, "will go out-of-range");
|
||||
|
||||
f(Number<Iter>{});
|
||||
static_for_impl<Iter + Increment, Remaining - Increment, Increment>{}(f);
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t Iter, index_t Increment>
|
||||
struct static_for_impl<Iter, 0, Increment>
|
||||
{
|
||||
template <class F>
|
||||
constexpr __host__ __device__ void operator()(F) const
|
||||
{
|
||||
// no work left, just return
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
// F signature: F(Number<Iter>)
|
||||
template <index_t NBegin, index_t NEnd, index_t Increment>
|
||||
struct static_for
|
||||
{
|
||||
template <class F>
|
||||
constexpr __host__ __device__ void operator()(F f) const
|
||||
{
|
||||
static_assert(NBegin <= NEnd, "wrongs! should have NBegin <= NEnd");
|
||||
|
||||
static_assert((NEnd - NBegin) % Increment == 0,
|
||||
"Wrong! should satisfy (NEnd - NBegin) % Increment == 0");
|
||||
|
||||
#if 0
|
||||
static_if<(NBegin < NEnd)>{}(
|
||||
[&](auto fwd) { static_for_impl<NBegin, NEnd - NBegin, fwd(Increment)>{}(f); });
|
||||
#else
|
||||
static_for_impl<NBegin, NEnd - NBegin, Increment>{}(f);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t NLoop>
|
||||
struct static_const_reduce_n
|
||||
|
||||
@@ -1,106 +1,80 @@
|
||||
#pragma once
|
||||
#include "functional.hip.hpp"
|
||||
#include "Sequence.hip.hpp"
|
||||
|
||||
// RemainLengths: Sequence<...>
|
||||
template <class RemainLengths>
|
||||
struct static_ford_impl
|
||||
#if 0
|
||||
template <index_t Iter, index_t Remaining, index_t Increment>
|
||||
struct static_for_impl
|
||||
{
|
||||
// F signature: F(Sequence<...> multi_id)
|
||||
// CurrentMultiIndex: Sequence<...>
|
||||
template <class F, class CurrentMultiIndex>
|
||||
__host__ __device__ void operator()(F f, CurrentMultiIndex) const
|
||||
{
|
||||
static_assert(RemainLengths::GetSize() > 0, "wrong! should not get here");
|
||||
|
||||
static_for<0, RemainLengths::Front(), 1>{}([=](auto I) {
|
||||
static_ford_impl<decltype(RemainLengths::PopFront())>{}(f,
|
||||
CurrentMultiIndex::PushBack(I));
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct static_ford_impl<Sequence<>>
|
||||
{
|
||||
// F signature: F(Sequence<...> multi_id)
|
||||
// CurrentMultiIndex: Sequence<...>
|
||||
template <class F, class CurrentMultiIndex>
|
||||
__host__ __device__ void operator()(F f, CurrentMultiIndex) const
|
||||
{
|
||||
f(CurrentMultiIndex{});
|
||||
}
|
||||
};
|
||||
|
||||
// Lengths is Sequence<...>
|
||||
template <class Lengths>
|
||||
struct static_ford
|
||||
{
|
||||
// F signature: F(Sequence<...> multi_id)
|
||||
template <class F>
|
||||
__host__ __device__ void operator()(F f) const
|
||||
constexpr __host__ __device__ void operator()(F f) const
|
||||
{
|
||||
static_assert(Lengths::GetSize() > 0, "wrong! Lengths is empty");
|
||||
static_assert(Remaining % Increment == 0, "wrong! Remaining % Increment != 0");
|
||||
static_assert(Increment <= Remaining, "will go out-of-range");
|
||||
|
||||
static_ford_impl<Lengths>{}(f, Sequence<>{});
|
||||
f(Number<Iter>{});
|
||||
static_for_impl<Iter + Increment, Remaining - Increment, Increment>{}(f);
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t RemainDim>
|
||||
struct ford_impl
|
||||
template <index_t Iter, index_t Increment>
|
||||
struct static_for_impl<Iter, 0, Increment>
|
||||
{
|
||||
// F signature: F(Array<...> multi_id)
|
||||
// CurrentMultiIndex: Array<...>
|
||||
// RemainLengths: Sequence<...>
|
||||
template <class F, class CurrentMultiIndex, class RemainLengths>
|
||||
__host__ __device__ void
|
||||
operator()(F f, CurrentMultiIndex current_multi_id, RemainLengths) const
|
||||
{
|
||||
static_assert(RemainLengths::GetSize() == RemainDim, "wrong!");
|
||||
static_assert(RemainDim > 1, "wrong!");
|
||||
|
||||
constexpr auto next_length = RemainLengths{}.Front();
|
||||
|
||||
for(index_t i = 0; i < next_length; ++i)
|
||||
{
|
||||
ford_impl<RemainDim - 1>{}(f, current_multi_id.PushBack(i), RemainLengths{}.PopFront());
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ford_impl<1>
|
||||
{
|
||||
// F signature: F(Array<...> multi_id)
|
||||
// CurrentMultiIndex: Array<...>
|
||||
// RemainLengths: Sequence<...>
|
||||
template <class F, class CurrentMultiIndex, class RemainLengths>
|
||||
__host__ __device__ void
|
||||
operator()(F f, CurrentMultiIndex current_multi_id, RemainLengths) const
|
||||
{
|
||||
static_assert(RemainLengths::GetSize() == 1, "wrong!");
|
||||
|
||||
constexpr index_t last_length = RemainLengths{}.Front();
|
||||
|
||||
for(index_t i = 0; i < last_length; ++i)
|
||||
{
|
||||
f(current_multi_id.PushBack(i));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Lengths is Sequence<...>
|
||||
template <class Lengths>
|
||||
struct ford
|
||||
{
|
||||
// F signature: F(Array<...> multi_id)
|
||||
template <class F>
|
||||
__host__ __device__ void operator()(F f) const
|
||||
constexpr __host__ __device__ void operator()(F) const
|
||||
{
|
||||
constexpr index_t first_length = Lengths{}.Front();
|
||||
|
||||
for(index_t i = 0; i < first_length; ++i)
|
||||
{
|
||||
ford_impl<Lengths::GetSize() - 1>{}(f, Array<index_t, 1>{i}, Lengths{}.PopFront());
|
||||
}
|
||||
// no work left, just return
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
// F signature: F(Number<Iter>)
|
||||
template <index_t NBegin, index_t NEnd, index_t Increment>
|
||||
struct static_for
|
||||
{
|
||||
template <class F>
|
||||
constexpr __host__ __device__ void operator()(F f) const
|
||||
{
|
||||
static_assert(NBegin <= NEnd, "wrongs! should have NBegin <= NEnd");
|
||||
|
||||
static_assert((NEnd - NBegin) % Increment == 0,
|
||||
"Wrong! should satisfy (NEnd - NBegin) % Increment == 0");
|
||||
|
||||
#if 0
|
||||
static_if<(NBegin < NEnd)>{}(
|
||||
[&](auto fwd) { static_for_impl<NBegin, NEnd - NBegin, fwd(Increment)>{}(f); });
|
||||
#else
|
||||
static_for_impl<NBegin, NEnd - NBegin, Increment>{}(f);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
#else
|
||||
template <class>
|
||||
struct static_for_impl;
|
||||
|
||||
template <index_t... Is>
|
||||
struct static_for_impl<Sequence<Is...>>
|
||||
{
|
||||
template <class F>
|
||||
__host__ __device__ constexpr void operator()(F f) const
|
||||
{
|
||||
swallow{(f(Number<Is>{}), 0)...};
|
||||
}
|
||||
};
|
||||
|
||||
// F signature: F(Number<Iter>)
|
||||
template <index_t NBegin, index_t NEnd, index_t Increment>
|
||||
struct static_for
|
||||
{
|
||||
template <class F>
|
||||
__host__ __device__ constexpr void operator()(F f) const
|
||||
{
|
||||
static_assert(NBegin <= NEnd, "wrongs! should have NBegin <= NEnd");
|
||||
|
||||
static_assert((NEnd - NBegin) % Increment == 0,
|
||||
"Wrong! should satisfy (NEnd - NBegin) % Increment == 0");
|
||||
|
||||
static_for_impl<typename arithmetic_sequence_gen<NBegin, NEnd, Increment>::SeqType>{}(f);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
109
src/include/functional3.hip.hpp
Normal file
109
src/include/functional3.hip.hpp
Normal file
@@ -0,0 +1,109 @@
|
||||
#pragma once
|
||||
#include "functional.hip.hpp"
|
||||
#include "functional2.hip.hpp"
|
||||
#include "Sequence.hip.hpp"
|
||||
#include "Array.hip.hpp"
|
||||
|
||||
// RemainLengths: Sequence<...>
|
||||
template <class RemainLengths>
|
||||
struct static_ford_impl
|
||||
{
|
||||
// F signature: F(Sequence<...> multi_id)
|
||||
// CurrentMultiIndex: Sequence<...>
|
||||
template <class F, class CurrentMultiIndex>
|
||||
__host__ __device__ void operator()(F f, CurrentMultiIndex) const
|
||||
{
|
||||
static_assert(RemainLengths::GetSize() > 0, "wrong! should not get here");
|
||||
|
||||
static_for<0, RemainLengths::Front(), 1>{}([=](auto I) {
|
||||
static_ford_impl<decltype(RemainLengths::PopFront())>{}(f,
|
||||
CurrentMultiIndex::PushBack(I));
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct static_ford_impl<Sequence<>>
|
||||
{
|
||||
// F signature: F(Sequence<...> multi_id)
|
||||
// CurrentMultiIndex: Sequence<...>
|
||||
template <class F, class CurrentMultiIndex>
|
||||
__host__ __device__ void operator()(F f, CurrentMultiIndex) const
|
||||
{
|
||||
f(CurrentMultiIndex{});
|
||||
}
|
||||
};
|
||||
|
||||
// Lengths is Sequence<...>
|
||||
template <class Lengths>
|
||||
struct static_ford
|
||||
{
|
||||
// F signature: F(Sequence<...> multi_id)
|
||||
template <class F>
|
||||
__host__ __device__ void operator()(F f) const
|
||||
{
|
||||
static_assert(Lengths::GetSize() > 0, "wrong! Lengths is empty");
|
||||
|
||||
static_ford_impl<Lengths>{}(f, Sequence<>{});
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t RemainDim>
|
||||
struct ford_impl
|
||||
{
|
||||
// F signature: F(Array<...> multi_id)
|
||||
// CurrentMultiIndex: Array<...>
|
||||
// RemainLengths: Sequence<...>
|
||||
template <class F, class CurrentMultiIndex, class RemainLengths>
|
||||
__host__ __device__ void
|
||||
operator()(F f, CurrentMultiIndex current_multi_id, RemainLengths) const
|
||||
{
|
||||
static_assert(RemainLengths::GetSize() == RemainDim, "wrong!");
|
||||
static_assert(RemainDim > 1, "wrong!");
|
||||
|
||||
constexpr auto next_length = RemainLengths{}.Front();
|
||||
|
||||
for(index_t i = 0; i < next_length; ++i)
|
||||
{
|
||||
ford_impl<RemainDim - 1>{}(f, current_multi_id.PushBack(i), RemainLengths{}.PopFront());
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ford_impl<1>
|
||||
{
|
||||
// F signature: F(Array<...> multi_id)
|
||||
// CurrentMultiIndex: Array<...>
|
||||
// RemainLengths: Sequence<...>
|
||||
template <class F, class CurrentMultiIndex, class RemainLengths>
|
||||
__host__ __device__ void
|
||||
operator()(F f, CurrentMultiIndex current_multi_id, RemainLengths) const
|
||||
{
|
||||
static_assert(RemainLengths::GetSize() == 1, "wrong!");
|
||||
|
||||
constexpr index_t last_length = RemainLengths{}.Front();
|
||||
|
||||
for(index_t i = 0; i < last_length; ++i)
|
||||
{
|
||||
f(current_multi_id.PushBack(i));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Lengths is Sequence<...>
|
||||
template <class Lengths>
|
||||
struct ford
|
||||
{
|
||||
// F signature: F(Array<...> multi_id)
|
||||
template <class F>
|
||||
__host__ __device__ void operator()(F f) const
|
||||
{
|
||||
constexpr index_t first_length = Lengths{}.Front();
|
||||
|
||||
for(index_t i = 0; i < first_length; ++i)
|
||||
{
|
||||
ford_impl<Lengths::GetSize() - 1>{}(f, Array<index_t, 1>{i}, Lengths{}.PopFront());
|
||||
}
|
||||
}
|
||||
};
|
||||
@@ -246,7 +246,7 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
|
||||
|
||||
// choose GEMM implementation here
|
||||
const auto run_blockwise_gemm = [&](auto... Xs) {
|
||||
#if 0
|
||||
#if 1
|
||||
return blockwise_gemm.Run(Xs...);
|
||||
#else
|
||||
return blockwise_gemm.Run_asm(Xs...);
|
||||
|
||||
@@ -77,14 +77,14 @@ __device__ void threadwise_generic_tensor_slice_copy_v1(
|
||||
*reinterpret_cast<const vector_t*>(&p_src[src_index]);
|
||||
});
|
||||
#else
|
||||
static_ford<decltype(access_lengths)>{}([&](auto access_multi_id_) {
|
||||
const auto access_multi_id = sequence2array(access_multi_id_);
|
||||
static_ford<decltype(access_lengths)>{}([&](auto access_multi_id) {
|
||||
constexpr index_t itmp = access_multi_id.Back() * DataPerAccess;
|
||||
|
||||
auto data_multi_id_in_access_order = access_multi_id;
|
||||
data_multi_id_in_access_order[nDim - 1] = access_multi_id[nDim - 1] * DataPerAccess;
|
||||
constexpr auto data_multi_id_in_access_order =
|
||||
access_multi_id.Modify(Number<nDim - 1>{}, Number<itmp>{});
|
||||
|
||||
const auto data_multi_id =
|
||||
reorder_array_given_old2new(data_multi_id_in_access_order, DimAccessOrder{});
|
||||
constexpr auto data_multi_id = reorder_array_given_old2new(
|
||||
sequence2array(data_multi_id_in_access_order), DimAccessOrder{});
|
||||
|
||||
const index_t src_index =
|
||||
SrcDesc::GetOffsetFromMultiIndex(src_multi_id_begin + data_multi_id);
|
||||
|
||||
Reference in New Issue
Block a user