This commit is contained in:
Chao Liu
2019-05-31 22:28:32 -05:00
parent 8d4607403e
commit 97ba755f2f
22 changed files with 168 additions and 543 deletions

View File

@@ -2,25 +2,23 @@
#include "common.hip.hpp"
template <class Lengths>
__host__ __device__ constexpr auto calculate_tensor_strides_default_rank_packed(Lengths)
__host__ __device__ constexpr auto calculate_tensor_strides_packed(Lengths)
{
return reverse_inclusive_scan_sequence(Lengths{}.PopFront(), mod_conv::multiplies<index_t>{})
.PushBack(Number<1>{});
}
template <class Lengths, index_t Align>
__host__ __device__ constexpr auto calculate_tensor_strides_default_rank_aligned(Lengths,
Number<Align>)
__host__ __device__ constexpr auto calculate_tensor_strides_aligned(Lengths, Number<Align>)
{
constexpr index_t L_back_align =
Align * mod_conv::integer_divide_ceiler<index_t>{}(Lengths{}.Back(), Align);
return calculate_tensor_strides_default_rank_packed(
return calculate_tensor_strides_packed(
Lengths{}.Modify(Number<Lengths{}.GetSize() - 1>{}, Number<L_back_align>{}));
}
// MemoryRanks of dimensions is for conversion from offset to multi-index
template <class Lengths, class Strides, class MemoryRanks>
template <class Lengths, class Strides>
struct ConstantTensorDescriptor
{
using Type = ConstantTensorDescriptor;
@@ -29,15 +27,7 @@ struct ConstantTensorDescriptor
__host__ __device__ constexpr ConstantTensorDescriptor()
{
static_assert(Lengths::GetSize() == Strides::GetSize() &&
Lengths::GetSize() == MemoryRanks::GetSize(),
"nDim not consistent");
#if 0 // require sequence_sort, but it's not implemented yet
static_assert(is_same<typename sequence_sort<MemoryRanks>::SortedSeqType,
typename arithmetic_sequence_gen<0, nDim, 1>::SeqType>::value,
"wrong! invalid MemoryRanks");
#endif
static_assert(Lengths::GetSize() == Strides::GetSize(), "nDim not consistent");
}
__host__ __device__ static constexpr auto GetOriginalTensorDescriptor() { return Type{}; }
@@ -54,8 +44,6 @@ struct ConstantTensorDescriptor
__host__ __device__ static constexpr auto GetStrides() { return Strides{}; }
__host__ __device__ static constexpr auto GetMemoryRanks() { return MemoryRanks{}; }
template <index_t I>
__host__ __device__ static constexpr index_t GetLength(Number<I>)
{
@@ -68,12 +56,6 @@ struct ConstantTensorDescriptor
return Strides{}.Get(Number<I>{});
}
template <index_t I>
__host__ __device__ static constexpr index_t GetMemoryRank(Number<I>)
{
return MemoryRanks{}.Get(Number<I>{});
}
__host__ __device__ static constexpr bool AreStridesNonAscending()
{
bool flag = true;
@@ -98,20 +80,13 @@ struct ConstantTensorDescriptor
return accumulate_on_sequence(Lengths{}, mod_conv::multiplies<index_t>{}, Number<1>{});
}
// WRONG! ReorderGivenOld2New is broken
template <class Align = Number<1>>
__host__ __device__ static constexpr index_t GetElementSpace(Align align = Align{})
{
#if 0
constexpr auto lengths_in_rank = GetLengths().ReorderGivenOld2New(MemoryRank{});
constexpr auto strides_in_rank = GetStrides().ReorderGivenOld2new(MemoryRank{});
constexpr index_t element_space_unaligned = accumulate_on_sequence(
(lengths_in_rank - Number<1>{}) * strides_in_rank, mod_conv::plus<index_t>{}, Number<1>{});
#else // WRONG! align shouldbe applied to the last memory rank, not the last tensor dimension
// This is WRONG! align shouldbe applied to the last memory rank, not the last tensor
// dimension
constexpr index_t element_space_unaligned = accumulate_on_sequence(
(GetLengths() - Number<1>{}) * GetStrides(), mod_conv::plus<index_t>{}, Number<1>{});
#endif
return align.Get() * ((element_space_unaligned + align.Get() - 1) / align.Get());
}
@@ -148,35 +123,13 @@ struct ConstantTensorDescriptor
multi_id * GetStrides(), mod_conv::plus<index_t>{}, Number<0>{});
}
#if 0 // ReorderGivenOld2new is broken
__host__ __device__ static Array<index_t, nDim> GetMultiIndexFromOffset(index_t offset)
{
Array<index_t, nDim> ranked_multi_id;
constexpr auto ranked_strides =
GetStrides().ReorderGivenOld2new(MemoryRanks{}); // check this
// calculate index in each of the dimensions in the order of their rank (not dimension)
static_for<0, nDim - 1, 1>{}([&](auto IDim) {
constexpr index_t idim = IDim.Get();
constexpr index_t stride = ranked_strides.Get(Number<idim>{});
ranked_multi_id[idim] = offset / stride;
offset -= ranked_multi_id[idim] * stride;
});
ranked_multi_id[nDim - 1] = offset / ranked_strides.Get(Number<nDim - 1>{});
return reorder_array_given_new2old(ranked_multi_id, MemoryRanks{}); // check this
}
#endif
__host__ __device__ static Array<index_t, nDim> GetMultiIndexFrom1dIndex(index_t id)
{
Array<index_t, nDim> multi_id;
constexpr auto dummy_strides = calculate_tensor_strides_default_rank_packed(GetLengths());
constexpr auto dummy_strides = calculate_tensor_strides_packed(GetLengths());
// calculate index in each of the dimensions in the order of their dimension (not rank)
// calculate index in each of the dimensions in the order of their dimension
static_for<0, nDim - 1, 1>{}([&](auto IDim) {
constexpr index_t idim = IDim.Get();
constexpr index_t stride = dummy_strides.Get(Number<idim>{});
@@ -267,24 +220,16 @@ struct ConstantTensorDescriptor
return new_multi_id;
}
// WRONG! Ranks is broken
template <index_t... IDims>
__host__ __device__ static constexpr auto Extract(Number<IDims>... extract_dims)
{
static_assert(sizeof...(IDims) <= GetNumOfDimension(),
"wrong! too many number of dimensions to be extracted");
using extract_lengths = decltype(Lengths{}.Extract(extract_dims...));
using extract_strides = decltype(Strides{}.Extract(extract_dims...));
using extract_ranks = decltype(MemoryRanks{}.Extract(extract_dims...));
using extract_lengths = decltype(Lengths::Extract(extract_dims...));
using extract_strides = decltype(Strides::Extract(extract_dims...));
#if 0
using new_ranks = typename sequence_sort<extract_ranks>::Original2SortedType;
#else // WRONG! TODO:: implement sequence_sort
using new_ranks = typename arithmetic_sequence_gen<0, sizeof...(IDims), 1>::SeqType;
#endif
return ConstantTensorDescriptor<extract_lengths, extract_strides, new_ranks>{};
return ConstantTensorDescriptor<extract_lengths, extract_strides>{};
}
template <index_t... IDims>
@@ -298,12 +243,8 @@ struct ConstantTensorDescriptor
{
using leaf_tensor = ConstantTensorDescriptor<Ts...>;
// memory rank is broken
// TODO: remove memory rank info from tensor descritpor
return ConstantTensorDescriptor<decltype(GetLengths().Append(leaf_tensor::GetLengths())),
decltype(GetStrides().Append(leaf_tensor::GetStrides())),
decltype(GetMemoryRanks().Append(
leaf_tensor::GetMemoryRanks()))>{};
decltype(GetStrides().Append(leaf_tensor::GetStrides()))>{};
}
template <index_t IDim, index_t SliceLen>
@@ -311,18 +252,9 @@ struct ConstantTensorDescriptor
{
using slice_lengths = decltype(Lengths{}.Modify(Number<IDim>{}, Number<SliceLen>{}));
return ConstantTensorDescriptor<slice_lengths, Strides, MemoryRanks>{};
return ConstantTensorDescriptor<slice_lengths, Strides>{};
}
template <index_t Threashold, index_t Delta>
struct f_fold_impl
{
__host__ __device__ constexpr index_t operator()(index_t x) const
{
return x > Threashold ? x + Delta : x;
}
};
template <index_t IDim, index_t... FoldIntervals>
__host__ __device__ static constexpr auto Fold(Number<IDim>, Number<FoldIntervals>...)
{
@@ -333,7 +265,6 @@ struct ConstantTensorDescriptor
constexpr auto unfold_length = GetLength(Number<IDim>{});
constexpr auto unfold_stride = GetStride(Number<IDim>{});
constexpr auto unfold_rank = GetMemoryRank(Number<IDim>{});
// length of the dimension to be folded needs to be dividable by fold_interval_product,
// otherwise, folding is invalid
@@ -350,16 +281,6 @@ struct ConstantTensorDescriptor
reverse_inclusive_scan_sequence(fold_intervals.PushBack(Number<1>{}),
mod_conv::multiplies<index_t>{});
// folded_ranks
constexpr auto fold_ranks =
typename arithmetic_sequence_gen<unfold_rank,
unfold_rank + fold_intervals.GetSize() + 1,
1>::SeqType{};
// increase the ranks that are larger than unfold_rank
constexpr auto tmp_ranks = transform_sequences(
f_fold_impl<unfold_rank, fold_intervals.GetSize()>{}, GetMemoryRanks());
// left and right
constexpr auto left = typename arithmetic_sequence_gen<0, IDim, 1>::SeqType{};
constexpr auto right =
@@ -369,15 +290,8 @@ struct ConstantTensorDescriptor
GetLengths().Extract(left).Append(fold_lengths).Append(GetLengths().Extract(right));
constexpr auto new_strides =
GetStrides().Extract(left).Append(fold_strides).Append(GetStrides().Extract(right));
constexpr auto new_ranks =
tmp_ranks.Extract(left).Append(fold_ranks).Append(tmp_ranks.Extract(right));
static_assert(new_ranks.GetSize() == new_lengths.GetSize(), "wrong!");
static_assert(fold_ranks.GetSize() == fold_lengths.GetSize(), "wrong!");
return ConstantTensorDescriptor<decltype(new_lengths),
decltype(new_strides),
decltype(new_ranks)>{};
return ConstantTensorDescriptor<decltype(new_lengths), decltype(new_strides)>{};
}
template <index_t Threashold, index_t Delta>
@@ -411,11 +325,6 @@ struct ConstantTensorDescriptor
// check if packed
static_assert(GetStride(IDim_p1) * GetLength(IDim_p1) == GetStride(IDim),
"wrong! dimensions to be unfolded need to be packed");
// check ranks
static_assert(GetMemoryRank(IDim_p1) == GetMemoryRank(IDim) + 1,
"wrong! ranks of dimensions to be unfolded need to be in increasing and "
"continuous ranks");
});
#endif
@@ -426,21 +335,13 @@ struct ConstantTensorDescriptor
constexpr auto right =
typename arithmetic_sequence_gen<LastUnfoldDim + 1, GetNumOfDimension(), 1>::SeqType{};
// unfolded length, stride and rank
// unfolded length, stride
constexpr index_t unfold_length = accumulate_on_sequence(
GetLengths().Extract(middle), mod_conv::multiplies<index_t>{}, Number<1>{});
constexpr index_t unfold_stride = GetStride(Number<LastUnfoldDim>{});
constexpr index_t unfold_rank = GetMemoryRank(Number<FirstUnfoldDim>{});
// decrease the ranks that are larger than the rank of LastUnfoldDim
constexpr auto tmp_ranks =
transform_sequences(f_unfold_impl<GetMemoryRank(Number<LastUnfoldDim>{}),
LastUnfoldDim - FirstUnfoldDim + 1>{},
GetMemoryRanks());
// new lengths, strides and ranks
// new lengths, strides
constexpr auto new_lengths = GetLengths()
.Extract(left)
.PushBack(Number<unfold_length>{})
@@ -451,22 +352,14 @@ struct ConstantTensorDescriptor
.PushBack(Number<unfold_stride>{})
.Append(GetStrides().Extract(right));
constexpr auto new_ranks = tmp_ranks.Extract(left)
.PushBack(Number<unfold_rank>{})
.Append(tmp_ranks.Extract(right));
return ConstantTensorDescriptor<decltype(new_lengths),
decltype(new_strides),
decltype(new_ranks)>{};
return ConstantTensorDescriptor<decltype(new_lengths), decltype(new_strides)>{};
}
template <class MapNew2Old>
__host__ __device__ static constexpr auto ReorderGivenNew2Old(MapNew2Old)
{
return ConstantTensorDescriptor<decltype(Lengths{}.ReorderGivenNew2Old(MapNew2Old{})),
decltype(Strides{}.ReorderGivenNew2Old(MapNew2Old{})),
decltype(
MemoryRanks{}.ReorderGivenNew2Old(MapNew2Old{}))>{};
decltype(Strides{}.ReorderGivenNew2Old(MapNew2Old{}))>{};
}
#if 0 // require sequence_sort, which is not implemented yet
@@ -474,358 +367,108 @@ struct ConstantTensorDescriptor
__host__ __device__ static constexpr auto ReorderGivenOld2New(MapOld2New)
{
return ConstantTensorDescriptor<decltype(Lengths{}.ReorderGivenOld2New(MapOld2New{})),
decltype(Strides{}.ReorderGivenOld2New(MapOld2New{})),
decltype(
MemoryRanks{}.ReorderGivenOld2New(MapOld2New{}))>{};
decltype(Strides{}.ReorderGivenOld2New(MapOld2New{}))>{}
}
#endif
};
template <class Lengths>
__host__ __device__ constexpr auto make_ConstantTensorDescriptor_default_rank_packed(Lengths)
__host__ __device__ constexpr auto make_ConstantTensorDescriptor_packed(Lengths)
{
using Strides = decltype(calculate_tensor_strides_default_rank_packed(Lengths{}));
using MemoryRanks = typename arithmetic_sequence_gen<0, Lengths::GetSize(), 1>::SeqType;
return ConstantTensorDescriptor<Lengths, Strides, MemoryRanks>{};
using Strides = decltype(calculate_tensor_strides_packed(Lengths{}));
return ConstantTensorDescriptor<Lengths, Strides>{};
}
template <class Lengths, class Strides>
__host__ __device__ constexpr auto make_ConstantTensorDescriptor_default_rank(Lengths, Strides)
__host__ __device__ constexpr auto make_ConstantTensorDescriptor(Lengths, Strides)
{
using MemoryRanks = typename arithmetic_sequence_gen<0, Lengths::GetSize(), 1>::SeqType;
return ConstantTensorDescriptor<Lengths, Strides, MemoryRanks>{};
return ConstantTensorDescriptor<Lengths, Strides>{};
}
template <class Lengths, index_t Align>
__host__ __device__ constexpr auto make_ConstantTensorDescriptor_default_rank_aligned(Lengths,
Number<Align>)
__host__ __device__ constexpr auto make_ConstantTensorDescriptor_aligned(Lengths, Number<Align>)
{
using Strides =
decltype(calculate_tensor_strides_default_rank_aligned(Lengths{}, Number<Align>{}));
using MemoryRanks = typename arithmetic_sequence_gen<0, Lengths::GetSize(), 1>::SeqType;
return ConstantTensorDescriptor<Lengths, Strides, MemoryRanks>{};
using Strides = decltype(calculate_tensor_strides_aligned(Lengths{}, Number<Align>{}));
return ConstantTensorDescriptor<Lengths, Strides>{};
}
template <class TDesc>
__host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
template <index_t... Lengths, index_t... Strides>
__host__ __device__ void
print_ConstantTensorDescriptor(const char* s,
ConstantTensorDescriptor<Sequence<Lengths...>, Sequence<Strides...>>)
{
constexpr index_t ndim = TDesc::GetNumOfDimension();
constexpr index_t ndim = sizeof...(Lengths);
static_assert(ndim >= 1 && ndim <= 10, "wrong!");
static_assert(ndim > 0 && ndim <= 10, "wrong!");
static_if<ndim == 1>{}([&](auto fwd) {
constexpr auto I0 = Number<0>{};
constexpr auto desc = fwd(TDesc{});
printf("%s dim %u, lengths {%u}, strides {%u}, ranks {%u}\n",
s,
desc.GetNumOfDimension(),
desc.GetLength(I0),
desc.GetStride(I0),
desc.GetMemoryRank(I0));
static_if<ndim == 1>{}([&](auto) {
printf("%s dim %u, lengths {%u}, strides {%u}\n", s, ndim, Lengths..., Strides...);
});
static_if<ndim == 2>{}([&](auto fwd) {
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto desc = fwd(TDesc{});
printf("%s dim %u, lengths {%u %u}, strides {%u %u}, ranks {%u %u}\n",
s,
desc.GetNumOfDimension(),
desc.GetLength(I0),
desc.GetLength(I1),
desc.GetStride(I0),
desc.GetStride(I1),
desc.GetMemoryRank(I0),
desc.GetMemoryRank(I1));
static_if<ndim == 2>{}([&](auto) {
printf("%s dim %u, lengths {%u %u}, strides {%u %u}\n", s, ndim, Lengths..., Strides...);
});
static_if<ndim == 3>{}([&](auto fwd) {
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto desc = fwd(TDesc{});
printf("%s dim %u, lengths {%u %u %u}, strides {%u %u %u}, ranks {%u %u %u}\n",
s,
desc.GetNumOfDimension(),
desc.GetLength(I0),
desc.GetLength(I1),
desc.GetLength(I2),
desc.GetStride(I0),
desc.GetStride(I1),
desc.GetStride(I2),
desc.GetMemoryRank(I0),
desc.GetMemoryRank(I1),
desc.GetMemoryRank(I2));
static_if<ndim == 3>{}([&](auto) {
printf(
"%s dim %u, lengths {%u %u %u}, strides {%u %u %u}\n", s, ndim, Lengths..., Strides...);
});
static_if<ndim == 4>{}([&](auto fwd) {
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto desc = fwd(TDesc{});
printf("%s dim %u, lengths {%u %u %u %u}, strides {%u %u %u %u}, ranks {%u %u %u %u}\n",
static_if<ndim == 4>{}([&](auto) {
printf("%s dim %u, lengths {%u %u %u %u}, strides {%u %u %u %u}\n",
s,
desc.GetNumOfDimension(),
desc.GetLength(I0),
desc.GetLength(I1),
desc.GetLength(I2),
desc.GetLength(I3),
desc.GetStride(I0),
desc.GetStride(I1),
desc.GetStride(I2),
desc.GetStride(I3),
desc.GetMemoryRank(I0),
desc.GetMemoryRank(I1),
desc.GetMemoryRank(I2),
desc.GetMemoryRank(I3));
ndim,
Lengths...,
Strides...);
});
static_if<ndim == 5>{}([&](auto fwd) {
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I4 = Number<4>{};
static_if<ndim == 5>{}([&](auto) {
printf("%s dim %u, lengths {%u %u %u %u %u}, strides {%u %u %u %u %u}\n",
s,
ndim,
Lengths...,
Strides...);
});
constexpr auto desc = fwd(TDesc{});
static_if<ndim == 6>{}([&](auto) {
printf("%s dim %u, lengths {%u %u %u %u %u %u}, strides {%u %u %u %u %u %u}\n",
s,
ndim,
Lengths...,
Strides...);
});
printf("%s dim %u, lengths {%u %u %u %u %u}, strides {%u %u %u %u %u}, ranks {%u %u %u %u "
static_if<ndim == 7>{}([&](auto) {
printf("%s dim %u, lengths {%u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u}\n",
s,
ndim,
Lengths...,
Strides...);
});
static_if<ndim == 8>{}([&](auto) {
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u %u}\n",
s,
ndim,
Lengths...,
Strides...);
});
static_if<ndim == 9>{}([&](auto) {
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u %u "
"%u}\n",
s,
desc.GetNumOfDimension(),
desc.GetLength(I0),
desc.GetLength(I1),
desc.GetLength(I2),
desc.GetLength(I3),
desc.GetLength(I4),
desc.GetStride(I0),
desc.GetStride(I1),
desc.GetStride(I2),
desc.GetStride(I3),
desc.GetStride(I4),
desc.GetMemoryRank(I0),
desc.GetMemoryRank(I1),
desc.GetMemoryRank(I2),
desc.GetMemoryRank(I3),
desc.GetMemoryRank(I4));
ndim,
Lengths...,
Strides...);
});
static_if<ndim == 6>{}([&](auto fwd) {
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I4 = Number<4>{};
constexpr auto I5 = Number<5>{};
constexpr auto desc = fwd(TDesc{});
printf("%s dim %u, lengths {%u %u %u %u %u %u}, strides {%u %u %u %u %u %u}, ranks {%u %u "
"%u %u %u %u}\n",
s,
desc.GetNumOfDimension(),
desc.GetLength(I0),
desc.GetLength(I1),
desc.GetLength(I2),
desc.GetLength(I3),
desc.GetLength(I4),
desc.GetLength(I5),
desc.GetStride(I0),
desc.GetStride(I1),
desc.GetStride(I2),
desc.GetStride(I3),
desc.GetStride(I4),
desc.GetStride(I5),
desc.GetMemoryRank(I0),
desc.GetMemoryRank(I1),
desc.GetMemoryRank(I2),
desc.GetMemoryRank(I3),
desc.GetMemoryRank(I4),
desc.GetMemoryRank(I5));
});
static_if<ndim == 7>{}([&](auto fwd) {
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I4 = Number<4>{};
constexpr auto I5 = Number<5>{};
constexpr auto I6 = Number<6>{};
constexpr auto desc = fwd(TDesc{});
printf("%s dim %u, lengths {%u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u}, ranks "
"{%u %u %u %u %u %u %u}\n",
s,
desc.GetNumOfDimension(),
desc.GetLength(I0),
desc.GetLength(I1),
desc.GetLength(I2),
desc.GetLength(I3),
desc.GetLength(I4),
desc.GetLength(I5),
desc.GetLength(I6),
desc.GetStride(I0),
desc.GetStride(I1),
desc.GetStride(I2),
desc.GetStride(I3),
desc.GetStride(I4),
desc.GetStride(I5),
desc.GetStride(I6),
desc.GetMemoryRank(I0),
desc.GetMemoryRank(I1),
desc.GetMemoryRank(I2),
desc.GetMemoryRank(I3),
desc.GetMemoryRank(I4),
desc.GetMemoryRank(I5),
desc.GetMemoryRank(I6));
});
static_if<ndim == 8>{}([&](auto fwd) {
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I4 = Number<4>{};
constexpr auto I5 = Number<5>{};
constexpr auto I6 = Number<6>{};
constexpr auto I7 = Number<7>{};
constexpr auto desc = fwd(TDesc{});
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u %u}, "
"ranks {%u %u %u %u %u %u %u %u}\n",
s,
desc.GetNumOfDimension(),
desc.GetLength(I0),
desc.GetLength(I1),
desc.GetLength(I2),
desc.GetLength(I3),
desc.GetLength(I4),
desc.GetLength(I5),
desc.GetLength(I6),
desc.GetLength(I7),
desc.GetStride(I0),
desc.GetStride(I1),
desc.GetStride(I2),
desc.GetStride(I3),
desc.GetStride(I4),
desc.GetStride(I5),
desc.GetStride(I6),
desc.GetStride(I7),
desc.GetMemoryRank(I0),
desc.GetMemoryRank(I1),
desc.GetMemoryRank(I2),
desc.GetMemoryRank(I3),
desc.GetMemoryRank(I4),
desc.GetMemoryRank(I5),
desc.GetMemoryRank(I6),
desc.GetMemoryRank(I7));
});
static_if<ndim == 9>{}([&](auto fwd) {
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I4 = Number<4>{};
constexpr auto I5 = Number<5>{};
constexpr auto I6 = Number<6>{};
constexpr auto I7 = Number<7>{};
constexpr auto I8 = Number<8>{};
constexpr auto desc = fwd(TDesc{});
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u %u "
"%u}, ranks {%u %u %u %u %u %u %u %u %u}\n",
s,
desc.GetNumOfDimension(),
desc.GetLength(I0),
desc.GetLength(I1),
desc.GetLength(I2),
desc.GetLength(I3),
desc.GetLength(I4),
desc.GetLength(I5),
desc.GetLength(I6),
desc.GetLength(I7),
desc.GetLength(I8),
desc.GetStride(I0),
desc.GetStride(I1),
desc.GetStride(I2),
desc.GetStride(I3),
desc.GetStride(I4),
desc.GetStride(I5),
desc.GetStride(I6),
desc.GetStride(I7),
desc.GetStride(I8),
desc.GetMemoryRank(I0),
desc.GetMemoryRank(I1),
desc.GetMemoryRank(I2),
desc.GetMemoryRank(I3),
desc.GetMemoryRank(I4),
desc.GetMemoryRank(I5),
desc.GetMemoryRank(I6),
desc.GetMemoryRank(I7),
desc.GetMemoryRank(I8));
});
static_if<ndim == 10>{}([&](auto fwd) {
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I4 = Number<4>{};
constexpr auto I5 = Number<5>{};
constexpr auto I6 = Number<6>{};
constexpr auto I7 = Number<7>{};
constexpr auto I8 = Number<8>{};
constexpr auto I9 = Number<9>{};
constexpr auto desc = fwd(TDesc{});
static_if<ndim == 10>{}([&](auto) {
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u "
"%u %u %u}, ranks {%u %u %u %u %u %u %u %u %u %u}\n",
"%u %u %u}\n",
s,
desc.GetNumOfDimension(),
desc.GetLength(I0),
desc.GetLength(I1),
desc.GetLength(I2),
desc.GetLength(I3),
desc.GetLength(I4),
desc.GetLength(I5),
desc.GetLength(I6),
desc.GetLength(I7),
desc.GetLength(I8),
desc.GetLength(I9),
desc.GetStride(I0),
desc.GetStride(I1),
desc.GetStride(I2),
desc.GetStride(I3),
desc.GetStride(I4),
desc.GetStride(I5),
desc.GetStride(I6),
desc.GetStride(I7),
desc.GetStride(I8),
desc.GetStride(I9),
desc.GetMemoryRank(I0),
desc.GetMemoryRank(I1),
desc.GetMemoryRank(I2),
desc.GetMemoryRank(I3),
desc.GetMemoryRank(I4),
desc.GetMemoryRank(I5),
desc.GetMemoryRank(I6),
desc.GetMemoryRank(I7),
desc.GetMemoryRank(I8),
desc.GetMemoryRank(I9));
ndim,
Lengths...,
Strides...);
});
}