refactoring ConstantTensorDescriptor

[ROCm/composable_kernel commit: a0584426ff]
This commit is contained in:
Chao Liu
2019-03-17 03:22:41 -05:00
parent 0485704cb3
commit 6fd0910da8
9 changed files with 452 additions and 340 deletions

View File

@@ -0,0 +1,49 @@
#pragma once
#include "constant_integral.hip.hpp"
template <unsigned NLoop>
struct static_loop_n
{
template <class F>
__host__ __device__ void operator()(F f) const
{
static_assert(NLoop > 1, "out-of-range");
f(Number<NLoop - 1>{});
static_loop_n<NLoop - 1>{}(f);
}
};
template <>
struct static_loop_n<1>
{
template <class F>
__host__ __device__ void operator()(F f) const
{
f(Number<0>{});
}
};
template <unsigned NLoop>
struct static_const_reduce_n
{
template <class F, class Reduce>
__host__ __device__ constexpr auto operator()(F f, Reduce r) const
{
static_assert(NLoop > 1, "out-of-range");
constexpr auto a = f(Number<NLoop - 1>{});
auto b = static_const_reduce_n<NLoop - 1>{}(f, r); // cannot use constexpr here, weird
return r(a, b);
}
};
template <>
struct static_const_reduce_n<1>
{
template <class F, class Reduce>
__host__ __device__ constexpr auto operator()(F f, Reduce) const
{
return f(Number<0>{});
}
};