mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
use more constexpr for Array
This commit is contained in:
@@ -34,14 +34,6 @@ struct Array
|
||||
|
||||
// Mutable element access: hands back a reference to slot i of mData.
// No range check is performed on i.
__host__ __device__ TData& operator()(index_t i)
{
    return mData[i];
}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr TData Get(Number<I>) const
|
||||
{
|
||||
static_assert(I < NSize, "wrong!");
|
||||
|
||||
return mData[I];
|
||||
}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr void Set(Number<I>, TData x)
|
||||
{
|
||||
@@ -50,16 +42,33 @@ struct Array
|
||||
mData[I] = x;
|
||||
}
|
||||
|
||||
// Store x into slot I of mData, where I is an ordinary (run-time capable) index.
// No range check is performed on I in this overload.
__host__ __device__ constexpr void Set(index_t I, TData x)
{
    mData[I] = x;
}
|
||||
|
||||
// Hand-written functor used in place of a constexpr lambda.
// Invoked with a compile-time index Number<I>, it copies element I of
// old_array into slot I of new_array (new_array has one more slot than old_array).
struct lambda_PushBack
{
    const Array<TData, NSize>& old_array; // array being read
    Array<TData, NSize + 1>& new_array;   // array being filled

    __host__ __device__ constexpr lambda_PushBack(const Array<TData, NSize>& old_array_,
                                                  Array<TData, NSize + 1>& new_array_)
        : old_array(old_array_), new_array(new_array_)
    {
    }

    // Copy a single element; the index is carried in the argument's type.
    template <index_t I>
    __host__ __device__ constexpr void operator()(Number<I> idx) const
    {
        new_array.Set(idx, old_array[I]);
    }
};
|
||||
|
||||
// Returns a new array, one element longer than this one, holding this array's
// NSize elements followed by x. This array is not modified.
__host__ __device__ constexpr auto PushBack(TData x) const
{
    Array<TData, NSize + 1> new_array;

    // Copy the existing elements exactly once. A named functor is used rather
    // than a lambda (lambda_PushBack "emulates a constexpr lambda").
    static_for<0, NSize, 1>{}(lambda_PushBack(*this, new_array));

    // Append x in the extra last slot.
    new_array.Set(Number<NSize>{}, x);

    return new_array;
}
|
||||
@@ -81,18 +90,13 @@ __host__ __device__ constexpr auto make_zero_array()
|
||||
|
||||
// Returns a reordered copy of old_array using a new-to-old index map:
// element k of the result is old_array[IRs_k], where IRs_k is the k-th value
// of the Sequence argument. Only the Sequence's type is used, so the
// parameter is unnamed.
//
// Compile-time requirements:
//   - the map has exactly NSize entries;
//   - the map satisfies is_valid_sequence_map.
template <class TData, index_t NSize, index_t... IRs>
__host__ __device__ constexpr auto reorder_array_given_new2old(const Array<TData, NSize>& old_array,
                                                               Sequence<IRs...> /*new2old*/)
{
    static_assert(NSize == sizeof...(IRs), "NSize not consistent");

    static_assert(is_valid_sequence_map<Sequence<IRs...>>::value, "wrong! invalid reorder map");

    // Gather the elements (not the size member) in the order given by the map.
    return Array<TData, NSize>{old_array[IRs]...};
}
|
||||
|
||||
template <class TData, index_t NSize, class MapOld2New>
|
||||
@@ -120,12 +124,14 @@ struct lambda_reorder_array_given_old2new
|
||||
|
||||
template <class TData, index_t NSize, index_t... IRs>
|
||||
__host__ __device__ constexpr auto reorder_array_given_old2new(const Array<TData, NSize>& old_array,
|
||||
Sequence<IRs...> old2new)
|
||||
Sequence<IRs...> /*old2new*/)
|
||||
{
|
||||
Array<TData, NSize> new_array;
|
||||
|
||||
static_assert(NSize == sizeof...(IRs), "NSize not consistent");
|
||||
|
||||
static_assert(is_valid_sequence_map<Sequence<IRs...>>::value, "wrong! invalid reorder map");
|
||||
|
||||
static_for<0, NSize, 1>{}(
|
||||
lambda_reorder_array_given_old2new<TData, NSize, Sequence<IRs...>>(old_array, new_array));
|
||||
|
||||
@@ -141,25 +147,44 @@ __host__ __device__ constexpr auto extract_array(const Array<TData, NSize>& old_
|
||||
|
||||
static_assert(new_size <= NSize, "wrong! too many extract");
|
||||
|
||||
static_for<0, new_size, 1>{}([&](auto I) {
|
||||
constexpr index_t i = I.Get();
|
||||
new_array(i) = old_array[ExtractSeq::Get(I)];
|
||||
});
|
||||
static_for<0, new_size, 1>{}([&](auto I) { new_array(I) = old_array[ExtractSeq::Get(I)]; });
|
||||
|
||||
return new_array;
|
||||
}
|
||||
|
||||
// Functor emulating a constexpr lambda for element-wise array math.
// It keeps references to a binary operation f, two inputs x and y, and an
// output z; calling it with Number<IDim_> stores f(x[IDim_], y[IDim_]) into
// z at position IDim_ through z.Set().
template <class F, class X, class Y, class Z>
struct lambda_array_math
{
    const F& f; // binary operation
    const X& x; // first operand
    const Y& y; // second operand
    Z& z;       // destination

    __host__ __device__ constexpr lambda_array_math(const F& f_, const X& x_, const Y& y_, Z& z_)
        : f(f_), x(x_), y(y_), z(z_)
    {
    }

    template <index_t IDim_>
    __host__ __device__ constexpr void operator()(Number<IDim_>) const
    {
        z.Set(Number<IDim_>{}, f(x[Number<IDim_>{}], y[Number<IDim_>{}]));
    }
};
|
||||
|
||||
// Array = Array + Array
// Element-wise sum of two arrays of the same element type and size.
// Returns a new array whose k-th element is a[k] + b[k].
template <class TData, index_t NSize>
__host__ __device__ constexpr auto operator+(Array<TData, NSize> a, Array<TData, NSize> b)
{
    Array<TData, NSize> result;

    // The functor is declared at function scope so it is visible to the
    // static_for call below, and is typed on TData so element values are not
    // converted through index_t.
    auto f = mod_conv::plus<TData>{};

    // One pass over the indices; lambda_array_math stores f(a[k], b[k]) into result.
    static_for<0, NSize, 1>{}(
        lambda_array_math<decltype(f), decltype(a), decltype(b), decltype(result)>(
            f, a, b, result));

    return result;
}
|
||||
@@ -170,11 +195,11 @@ __host__ __device__ constexpr auto operator-(Array<TData, NSize> a, Array<TData,
|
||||
{
|
||||
Array<TData, NSize> result;
|
||||
|
||||
static_for<0, NSize, 1>{}([&](auto I) {
|
||||
constexpr index_t i = I.Get();
|
||||
auto f = mod_conv::minus<index_t>{};
|
||||
|
||||
result(i) = a[i] - b[i];
|
||||
});
|
||||
static_for<0, NSize, 1>{}(
|
||||
lambda_array_math<decltype(f), decltype(a), decltype(b), decltype(result)>(
|
||||
f, a, b, result));
|
||||
|
||||
return result;
|
||||
}
|
||||
@@ -187,11 +212,11 @@ __host__ __device__ constexpr auto operator+(Array<TData, NSize> a, Sequence<Is.
|
||||
|
||||
Array<TData, NSize> result;
|
||||
|
||||
static_for<0, NSize, 1>{}([&](auto I) {
|
||||
constexpr index_t i = I.Get();
|
||||
auto f = mod_conv::plus<index_t>{};
|
||||
|
||||
result(i) = a[i] + b.Get(I);
|
||||
});
|
||||
static_for<0, NSize, 1>{}(
|
||||
lambda_array_math<decltype(f), decltype(a), decltype(b), decltype(result)>(
|
||||
f, a, b, result));
|
||||
|
||||
return result;
|
||||
}
|
||||
@@ -204,11 +229,11 @@ __host__ __device__ constexpr auto operator-(Array<TData, NSize> a, Sequence<Is.
|
||||
|
||||
Array<TData, NSize> result;
|
||||
|
||||
static_for<0, NSize, 1>{}([&](auto I) {
|
||||
constexpr index_t i = I.Get();
|
||||
auto f = mod_conv::minus<index_t>{};
|
||||
|
||||
result(i) = a[i] - b.Get(I);
|
||||
});
|
||||
static_for<0, NSize, 1>{}(
|
||||
lambda_array_math<decltype(f), decltype(a), decltype(b), decltype(result)>(
|
||||
f, a, b, result));
|
||||
|
||||
return result;
|
||||
}
|
||||
@@ -221,11 +246,11 @@ __host__ __device__ constexpr auto operator*(Array<TData, NSize> a, Sequence<Is.
|
||||
|
||||
Array<TData, NSize> result;
|
||||
|
||||
static_for<0, NSize, 1>{}([&](auto I) {
|
||||
constexpr index_t i = I.Get();
|
||||
auto f = mod_conv::multiplies<index_t>{};
|
||||
|
||||
result(i) = a[i] * b.Get(I);
|
||||
});
|
||||
static_for<0, NSize, 1>{}(
|
||||
lambda_array_math<decltype(f), decltype(a), decltype(b), decltype(result)>(
|
||||
f, a, b, result));
|
||||
|
||||
return result;
|
||||
}
|
||||
@@ -238,11 +263,11 @@ __host__ __device__ constexpr auto operator-(Sequence<Is...> a, Array<TData, NSi
|
||||
|
||||
Array<TData, NSize> result;
|
||||
|
||||
static_for<0, NSize, 1>{}([&](auto I) {
|
||||
constexpr index_t i = I.Get();
|
||||
auto f = mod_conv::minus<index_t>{};
|
||||
|
||||
result(i) = a.Get(I) - b[i];
|
||||
});
|
||||
static_for<0, NSize, 1>{}(
|
||||
lambda_array_math<decltype(f), decltype(a), decltype(b), decltype(result)>(
|
||||
f, a, b, result));
|
||||
|
||||
return result;
|
||||
}
|
||||
@@ -255,10 +280,7 @@ accumulate_on_array(const Array<TData, NSize>& a, Reduce f, TData init)
|
||||
|
||||
static_assert(NSize > 0, "wrong");
|
||||
|
||||
static_for<0, NSize, 1>{}([&](auto I) {
|
||||
constexpr index_t i = I.Get();
|
||||
result = f(result, a[i]);
|
||||
});
|
||||
static_for<0, NSize, 1>{}([&](auto I) { result = f(result, a[I]); });
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
@@ -48,13 +48,13 @@ struct ConstantTensorDescriptor
|
||||
// Length of dimension I, obtained from the Lengths type via its Get(Number<I>).
// The dimension index is carried in the argument's type, so the parameter is unnamed.
template <index_t I>
__host__ __device__ static constexpr index_t GetLength(Number<I>)
{
    return Lengths::Get(Number<I>{});
}
|
||||
|
||||
// Stride of dimension I, obtained from the Strides type via its Get(Number<I>).
// The dimension index is carried in the argument's type, so the parameter is unnamed.
template <index_t I>
__host__ __device__ static constexpr index_t GetStride(Number<I>)
{
    return Strides::Get(Number<I>{});
}
|
||||
|
||||
struct lambda_AreDimensionsContinuous
|
||||
@@ -131,7 +131,7 @@ struct ConstantTensorDescriptor
|
||||
// Adds the contribution of one dimension to the running offset:
// multi_id[IDim] multiplied by the stride of that dimension.
// The contribution is added exactly once per call.
template <class X>
__host__ __device__ constexpr void operator()(X IDim) const
{
    offset += multi_id[IDim] * Type::GetStride(IDim);
}
|
||||
};
|
||||
|
||||
|
||||
@@ -2,6 +2,9 @@
|
||||
#include "integral_constant.hip.hpp"
|
||||
#include "functional.hip.hpp"
|
||||
|
||||
template <class Seq>
|
||||
struct is_valid_sequence_map;
|
||||
|
||||
template <index_t... Is>
|
||||
struct Sequence
|
||||
{
|
||||
@@ -40,27 +43,24 @@ struct Sequence
|
||||
// Returns a Sequence whose k-th value is this Sequence's value at index IRs_k
// (a new-to-old reorder map). Only the map's type is used, so the parameter
// is unnamed.
//
// Compile-time requirements:
//   - the map has as many entries as this Sequence;
//   - the map satisfies is_valid_sequence_map.
template <index_t... IRs>
__host__ __device__ static constexpr auto ReorderGivenNew2Old(Sequence<IRs...> /*new2old*/)
{
    static_assert(sizeof...(Is) == sizeof...(IRs),
                  "wrong! reorder map should have the same size as Sequence to be rerodered");

    static_assert(is_valid_sequence_map<Sequence<IRs...>>::value, "wrong! invalid reorder map");

    return Sequence<Type::Get(Number<IRs>{})...>{};
}
|
||||
|
||||
#if 0 // require sequence_sort, which is not implemented yet
|
||||
template <class MapOld2New>
|
||||
__host__ __device__ static constexpr auto ReorderGivenOld2New(MapOld2New /*old2new*/)
|
||||
{
|
||||
#if 0
|
||||
static_assert(is_same<sequence_sort<MapOld2New>::SortedSeqType,
|
||||
arithmetic_sequence_gen<0, mSize, 1>::SeqType>::value,
|
||||
"wrong! invalid old2new map");
|
||||
#endif
|
||||
static_assert(sizeof...(Is) == MapOld2New::GetSize(),
|
||||
"wrong! reorder map should have the same size as Sequence to be rerodered");
|
||||
|
||||
static_assert(is_valid_sequence_map<MapOld2New>::value,
|
||||
"wrong! invalid reorder map");
|
||||
|
||||
constexpr auto map_new2old = typename sequence_map_inverse<MapOld2New>::SeqMapType{};
|
||||
|
||||
return ReorderGivenNew2Old(map_new2old);
|
||||
@@ -106,13 +106,13 @@ struct Sequence
|
||||
// Returns a Sequence made of this Sequence's values at the indices Ns...,
// where the indices are passed as individual Number<> arguments.
template <index_t... Ns>
__host__ __device__ static constexpr auto Extract(Number<Ns>...)
{
    return Sequence<Type::Get(Number<Ns>{})...>{};
}
|
||||
|
||||
// Returns a Sequence made of this Sequence's values at the indices Ns...,
// where the indices are passed packed in a single Sequence argument.
template <index_t... Ns>
__host__ __device__ static constexpr auto Extract(Sequence<Ns...>)
{
    return Sequence<Type::Get(Number<Ns>{})...>{};
}
|
||||
|
||||
template <index_t I, index_t X>
|
||||
@@ -316,6 +316,7 @@ struct sequence_map_inverse<Sequence<Is...>>
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
template <class Seq>
|
||||
struct is_valid_sequence_map
|
||||
{
|
||||
|
||||
113
src/include/base.hip.hpp
Normal file
113
src/include/base.hip.hpp
Normal file
@@ -0,0 +1,113 @@
|
||||
#pragma once
|
||||
|
||||
// Returns threadIdx.x, i.e. the calling thread's x index within its block.
__device__ index_t get_thread_local_1d_id()
{
    return threadIdx.x;
}
|
||||
|
||||
// Returns blockIdx.x, i.e. the calling block's x index within the grid.
__device__ index_t get_block_1d_id()
{
    return blockIdx.x;
}
|
||||
|
||||
// Compile-time type equality trait.
// Primary template: selected when the two type arguments differ.
template <class A, class B>
struct is_same
{
    static constexpr bool value = false;
};

// Partial specialization: selected when both arguments are the same type.
template <class A>
struct is_same<A, A>
{
    static constexpr bool value = true;
};
|
||||
|
||||
// Value-level wrapper around is_same: returns true when the two arguments
// have exactly the same (deduced) type. The argument values are ignored.
template <class X, class Y>
__host__ __device__ constexpr bool is_same_type(X, Y)
{
    constexpr bool same = is_same<X, Y>::value;
    return same;
}
|
||||
|
||||
namespace mod_conv { // namespace mod_conv

// Unary functor: multiplies its argument by the compile-time constant s.
template <class T, T s>
struct scales
{
    __host__ __device__ constexpr T operator()(T a) const { return s * a; }
};

// Binary functor: a + b.
template <class T>
struct plus
{
    __host__ __device__ constexpr T operator()(T a, T b) const { return a + b; }
};

// Binary functor: a - b.
template <class T>
struct minus
{
    __host__ __device__ constexpr T operator()(T a, T b) const { return a - b; }
};

// Binary functor: a * b.
template <class T>
struct multiplies
{
    __host__ __device__ constexpr T operator()(T a, T b) const { return a * b; }
};

// Binary functor computing (a + b - 1) / b, which is ceil(a / b) for
// a >= 0 and b > 0. Only instantiable for index_t or int.
template <class T>
struct integer_divide_ceiler
{
    __host__ __device__ constexpr T operator()(T a, T b) const
    {
        static_assert(is_same<T, index_t>::value || is_same<T, int>::value, "wrong type");

        return (a + b - 1) / b;
    }
};

// Free-function form of integer_divide_ceiler: (a + b - 1) / b.
template <class T>
__host__ __device__ constexpr T integer_divide_ceil(T a, T b)
{
    static_assert(is_same<T, index_t>::value || is_same<T, int>::value, "wrong type");

    return (a + b - 1) / b;
}

// Larger of two values of the same type.
template <class T>
__host__ __device__ constexpr T max(T x, T y)
{
    return x > y ? x : y;
}

// Largest of three or more values; all arguments must have the same type.
template <class T, class... Ts>
__host__ __device__ constexpr T max(T x, Ts... xs)
{
    static_assert(sizeof...(xs) > 0, "not enough argument");

    auto y = max(xs...);

    static_assert(is_same<decltype(y), T>::value, "not the same type");

    return x > y ? x : y;
}

// Smaller of two values of the same type.
template <class T>
__host__ __device__ constexpr T min(T x, T y)
{
    return x < y ? x : y;
}

// Smallest of three or more values; all arguments must have the same type.
template <class T, class... Ts>
__host__ __device__ constexpr T min(T x, Ts... xs)
{
    static_assert(sizeof...(xs) > 0, "not enough argument");

    auto y = min(xs...);

    static_assert(is_same<decltype(y), T>::value, "not the same type");

    return x < y ? x : y;
}

// Greatest common divisor by Euclid's algorithm; helper for lcm().
// Written recursively as a single return statement.
template <class T>
__host__ __device__ constexpr T lcm_gcd(T x, T y)
{
    return y == 0 ? x : lcm_gcd(y, x % y);
}

// Least common multiple of two values of the same integer type.
// Returns 0 if either argument is 0. The division is done before the
// multiplication to keep the intermediate value small.
// Intended for non-negative arguments.
template <class T>
__host__ __device__ constexpr T lcm(T x, T y)
{
    return (x == 0 || y == 0) ? T(0) : (x / lcm_gcd(x, y)) * y;
}

// Least common multiple of three or more values; all arguments must have the
// same type. Folds the two-argument lcm over the argument list.
template <class T, class... Ts>
__host__ __device__ constexpr T lcm(T x, Ts... xs)
{
    static_assert(sizeof...(xs) > 0, "not enough argument");

    auto y = lcm(xs...);

    static_assert(is_same<decltype(y), T>::value, "not the same type");

    return lcm(x, y);
}

} // namespace mod_conv
|
||||
@@ -203,20 +203,18 @@ struct BlockwiseGenericTensorSliceCopy_v1
|
||||
make_ConstantTensorDescriptor_packed(thread_sub_tensor_lengths * repeat_lengths);
|
||||
|
||||
static_ford<decltype(repeat_lengths)>{}([&](auto repeat_multi_id_) {
|
||||
#if 0
|
||||
#if 1
|
||||
constexpr auto repeat_multi_id = sequence2array(decltype(repeat_multi_id_){});
|
||||
|
||||
const auto src_thread_data_multi_id_begin =
|
||||
repeat_multi_id * data_per_cluster_per_dims; // cannot not constexpr, why?
|
||||
const auto src_thread_data_multi_id_begin = repeat_multi_id * data_per_cluster_per_dims;
|
||||
|
||||
const auto clipboard_data_multi_id_begin =
|
||||
repeat_multi_id * thread_sub_tensor_lengths; // cannot not constexpr, why?
|
||||
const auto clipboard_data_multi_id_begin = repeat_multi_id * thread_sub_tensor_lengths;
|
||||
|
||||
const index_t src_offset = SrcDesc{}.GetOffsetFromMultiIndex(
|
||||
src_thread_data_multi_id_begin); // cannot not constexpr, why?
|
||||
const index_t src_offset =
|
||||
SrcDesc{}.GetOffsetFromMultiIndex(src_thread_data_multi_id_begin);
|
||||
|
||||
const index_t clipboard_offset = thread_tensor_desc.GetOffsetFromMultiIndex(
|
||||
clipboard_data_multi_id_begin); // cannot not constexpr, why?
|
||||
const index_t clipboard_offset =
|
||||
thread_tensor_desc.GetOffsetFromMultiIndex(clipboard_data_multi_id_begin);
|
||||
#else
|
||||
constexpr auto repeat_multi_id = decltype(repeat_multi_id_){};
|
||||
|
||||
@@ -258,20 +256,17 @@ struct BlockwiseGenericTensorSliceCopy_v1
|
||||
make_ConstantTensorDescriptor_packed(thread_sub_tensor_lengths * repeat_lengths);
|
||||
|
||||
static_ford<decltype(repeat_lengths)>{}([&](auto repeat_multi_id_) {
|
||||
#if 0
|
||||
#if 1
|
||||
constexpr auto repeat_multi_id = sequence2array(decltype(repeat_multi_id_){});
|
||||
|
||||
const auto clipboard_data_multi_id_begin =
|
||||
repeat_multi_id * thread_sub_tensor_lengths; // cannot not constexpr, why?
|
||||
const auto clipboard_data_multi_id_begin = repeat_multi_id * thread_sub_tensor_lengths;
|
||||
|
||||
const auto dst_data_multi_id_begin =
|
||||
repeat_multi_id * data_per_cluster_per_dims; // cannot not constexpr, why?
|
||||
const auto dst_data_multi_id_begin = repeat_multi_id * data_per_cluster_per_dims;
|
||||
|
||||
const index_t clipboard_offset = thread_tensor_desc.GetOffsetFromMultiIndex(
|
||||
clipboard_data_multi_id_begin); // cannot not constexpr, why?
|
||||
const index_t clipboard_offset =
|
||||
thread_tensor_desc.GetOffsetFromMultiIndex(clipboard_data_multi_id_begin);
|
||||
|
||||
const index_t dst_offset = DstDesc{}.GetOffsetFromMultiIndex(
|
||||
dst_data_multi_id_begin); // cannot not constexpr, why?
|
||||
const index_t dst_offset = DstDesc{}.GetOffsetFromMultiIndex(dst_data_multi_id_begin);
|
||||
#else
|
||||
constexpr auto repeat_multi_id = decltype(repeat_multi_id_){};
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
#pragma once
|
||||
#include "base.hip.hpp"
|
||||
#include "vector_type.hip.hpp"
|
||||
#include "integral_constant.hip.hpp"
|
||||
#include "Sequence.hip.hpp"
|
||||
@@ -10,109 +11,3 @@
|
||||
#if USE_AMD_INLINE_ASM
|
||||
#include "amd_inline_asm.hip.hpp"
|
||||
#endif
|
||||
|
||||
__device__ index_t get_thread_local_1d_id() { return threadIdx.x; }
|
||||
|
||||
__device__ index_t get_block_1d_id() { return blockIdx.x; }
|
||||
|
||||
template <class T1, class T2>
|
||||
struct is_same
|
||||
{
|
||||
static constexpr bool value = false;
|
||||
};
|
||||
|
||||
template <class T>
|
||||
struct is_same<T, T>
|
||||
{
|
||||
static constexpr bool value = true;
|
||||
};
|
||||
|
||||
template <class X, class Y>
|
||||
__host__ __device__ constexpr bool is_same_type(X, Y)
|
||||
{
|
||||
return is_same<X, Y>::value;
|
||||
}
|
||||
|
||||
namespace mod_conv { // namespace mod_conv
|
||||
template <class T, T s>
|
||||
struct scales
|
||||
{
|
||||
__host__ __device__ constexpr T operator()(T a) const { return s * a; }
|
||||
};
|
||||
|
||||
template <class T>
|
||||
struct plus
|
||||
{
|
||||
__host__ __device__ constexpr T operator()(T a, T b) const { return a + b; }
|
||||
};
|
||||
|
||||
template <class T>
|
||||
struct multiplies
|
||||
{
|
||||
__host__ __device__ constexpr T operator()(T a, T b) const { return a * b; }
|
||||
};
|
||||
|
||||
template <class T>
|
||||
struct integer_divide_ceiler
|
||||
{
|
||||
__host__ __device__ constexpr T operator()(T a, T b) const
|
||||
{
|
||||
static_assert(is_same<T, index_t>::value || is_same<T, int>::value, "wrong type");
|
||||
|
||||
return (a + b - 1) / b;
|
||||
}
|
||||
};
|
||||
|
||||
template <class T>
|
||||
__host__ __device__ constexpr T integer_divide_ceil(T a, T b)
|
||||
{
|
||||
static_assert(is_same<T, index_t>::value || is_same<T, int>::value, "wrong type");
|
||||
|
||||
return (a + b - 1) / b;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
__host__ __device__ constexpr T max(T x, T y)
|
||||
{
|
||||
return x > y ? x : y;
|
||||
}
|
||||
|
||||
template <class T, class... Ts>
|
||||
__host__ __device__ constexpr T max(T x, Ts... xs)
|
||||
{
|
||||
static_assert(sizeof...(xs) > 0, "not enough argument");
|
||||
|
||||
auto y = max(xs...);
|
||||
|
||||
static_assert(is_same<decltype(y), T>::value, "not the same type");
|
||||
|
||||
return x > y ? x : y;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
__host__ __device__ constexpr T min(T x, T y)
|
||||
{
|
||||
return x < y ? x : y;
|
||||
}
|
||||
|
||||
template <class T, class... Ts>
|
||||
__host__ __device__ constexpr T min(T x, Ts... xs)
|
||||
{
|
||||
static_assert(sizeof...(xs) > 0, "not enough argument");
|
||||
|
||||
auto y = min(xs...);
|
||||
|
||||
static_assert(is_same<decltype(y), T>::value, "not the same type");
|
||||
|
||||
return x < y ? x : y;
|
||||
}
|
||||
|
||||
// this is wrong
|
||||
// TODO: implement correct least common multiple, instead of calling max()
|
||||
template <class T, class... Ts>
|
||||
__host__ __device__ constexpr T lcm(T x, Ts... xs)
|
||||
{
|
||||
return max(x, xs...);
|
||||
}
|
||||
|
||||
} // namespace mod_conv
|
||||
|
||||
@@ -11,7 +11,7 @@ struct static_ford_impl
|
||||
// F signature: F(Sequence<...> multi_id)
|
||||
// CurrentMultiIndex: Sequence<...>
|
||||
template <class F, class CurrentMultiIndex>
|
||||
__host__ __device__ void operator()(F f, CurrentMultiIndex) const
|
||||
__host__ __device__ constexpr void operator()(F f, CurrentMultiIndex) const
|
||||
{
|
||||
static_assert(RemainLengths::GetSize() > 0, "wrong! should not get here");
|
||||
|
||||
@@ -28,7 +28,7 @@ struct static_ford_impl<Sequence<>>
|
||||
// F signature: F(Sequence<...> multi_id)
|
||||
// CurrentMultiIndex: Sequence<...>
|
||||
// Terminal case of the static multi-index loop: calls f once with a
// default-constructed CurrentMultiIndex.
// F signature: F(Sequence<...> multi_id)
template <class F, class CurrentMultiIndex>
__host__ __device__ constexpr void operator()(F f, CurrentMultiIndex) const
{
    f(CurrentMultiIndex{});
}
|
||||
@@ -40,7 +40,7 @@ struct static_ford
|
||||
{
|
||||
// F signature: F(Sequence<...> multi_id)
|
||||
template <class F>
|
||||
__host__ __device__ void operator()(F f) const
|
||||
__host__ __device__ constexpr void operator()(F f) const
|
||||
{
|
||||
static_assert(Lengths::GetSize() > 0, "wrong! Lengths is empty");
|
||||
|
||||
@@ -55,7 +55,7 @@ struct ford_impl
|
||||
// CurrentMultiIndex: Array<...>
|
||||
// RemainLengths: Sequence<...>
|
||||
template <class F, class CurrentMultiIndex, class RemainLengths>
|
||||
__host__ __device__ void
|
||||
__host__ __device__ constexpr void
|
||||
operator()(F f, CurrentMultiIndex current_multi_id, RemainLengths) const
|
||||
{
|
||||
static_assert(RemainLengths::GetSize() == RemainDim, "wrong!");
|
||||
@@ -77,7 +77,7 @@ struct ford_impl<1>
|
||||
// CurrentMultiIndex: Array<...>
|
||||
// RemainLengths: Sequence<...>
|
||||
template <class F, class CurrentMultiIndex, class RemainLengths>
|
||||
__host__ __device__ void
|
||||
__host__ __device__ constexpr void
|
||||
operator()(F f, CurrentMultiIndex current_multi_id, RemainLengths) const
|
||||
{
|
||||
static_assert(RemainLengths::GetSize() == 1, "wrong!");
|
||||
@@ -97,7 +97,7 @@ struct ford
|
||||
{
|
||||
// F signature: F(Array<...> multi_id)
|
||||
template <class F>
|
||||
__host__ __device__ void operator()(F f) const
|
||||
__host__ __device__ constexpr void operator()(F f) const
|
||||
{
|
||||
constexpr index_t first_length = Lengths{}.Front();
|
||||
|
||||
|
||||
Reference in New Issue
Block a user