debugging implicit gemm v1: use 10d tensor output

This commit is contained in:
Chao Liu
2019-04-08 10:27:32 -05:00
parent 90abf42799
commit c9fa46af0b
17 changed files with 324 additions and 178 deletions

View File

@@ -44,6 +44,32 @@ __host__ __device__ constexpr auto
1>{};
}
// Default strides for a packed 10d tensor whose lengths are given as a Sequence.
// Stride i is the product of every length after dimension i, so the last
// dimension gets stride 1 and L0 never contributes to any stride.
// This is ugly: hand-expanded, only for 10d.
template <index_t L0,
          index_t L1,
          index_t L2,
          index_t L3,
          index_t L4,
          index_t L5,
          index_t L6,
          index_t L7,
          index_t L8,
          index_t L9>
__host__ __device__ constexpr auto
calculate_default_strides(Sequence<L0, L1, L2, L3, L4, L5, L6, L7, L8, L9>)
{
    // Accumulate suffix products from the innermost dimension outwards, so each
    // stride reuses the one computed for the dimension to its right.
    constexpr index_t stride8 = L9;
    constexpr index_t stride7 = L8 * stride8;
    constexpr index_t stride6 = L7 * stride7;
    constexpr index_t stride5 = L6 * stride6;
    constexpr index_t stride4 = L5 * stride5;
    constexpr index_t stride3 = L4 * stride4;
    constexpr index_t stride2 = L3 * stride3;
    constexpr index_t stride1 = L2 * stride2;
    constexpr index_t stride0 = L1 * stride1;

    return Sequence<stride0,
                    stride1,
                    stride2,
                    stride3,
                    stride4,
                    stride5,
                    stride6,
                    stride7,
                    stride8,
                    1>{};
}
// this is ugly, only for 2d
template <index_t L0, index_t L1, index_t Align>
__host__ __device__ constexpr auto calculate_default_strides_aligned(Sequence<L0, L1>,