implicit gemm v1r2: only load 1d filter

This commit is contained in:
Chao Liu
2019-04-13 11:19:17 -05:00
parent 96ee9571e2
commit 00899f191b
17 changed files with 426 additions and 142 deletions

View File

@@ -8,6 +8,13 @@ __host__ __device__ constexpr auto calculate_default_strides(Sequence<L0, L1>)
return Sequence<L1, 1>{};
}
// 3d overload (hand-written for this rank only): strides of a densely packed
// tensor whose last dimension has stride 1. L0 does not influence any stride.
template <index_t L0, index_t L1, index_t L2>
__host__ __device__ constexpr auto calculate_default_strides(Sequence<L0, L1, L2>)
{
    // build the strides from the innermost dimension outwards
    constexpr index_t stride_1 = L2;
    constexpr index_t stride_0 = L1 * stride_1;

    return Sequence<stride_0, stride_1, 1>{};
}
// this is ugly, only for 4d
template <index_t L0, index_t L1, index_t L2, index_t L3>
__host__ __device__ constexpr auto calculate_default_strides(Sequence<L0, L1, L2, L3>)
@@ -79,6 +86,15 @@ __host__ __device__ constexpr auto calculate_default_strides_aligned(Sequence<L0
return Sequence<L1_align, 1>{};
}
// 3d overload (hand-written for this rank only): default strides where the
// innermost length L2 is first rounded up to a multiple of Align, so that
// stride 1 (and therefore stride 0) is a multiple of Align.
template <index_t L0, index_t L1, index_t L2, index_t Align>
__host__ __device__ constexpr auto calculate_default_strides_aligned(Sequence<L0, L1, L2>,
                                                                     Number<Align>)
{
    // number of Align-sized units needed to cover L2 (integer ceiling division)
    constexpr index_t num_align_units = (L2 + Align - 1) / Align;

    // padded innermost length, then the strides from the inside out
    constexpr index_t stride_1 = Align * num_align_units;
    constexpr index_t stride_0 = L1 * stride_1;

    return Sequence<stride_0, stride_1, 1>{};
}
// this is ugly, only for 4d
template <index_t L0, index_t L1, index_t L2, index_t L3, index_t Align>
__host__ __device__ constexpr auto calculate_default_strides_aligned(Sequence<L0, L1, L2, L3>,
@@ -244,6 +260,22 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
desc.GetStride(I0),
desc.GetStride(I1));
}
else if(ndim == 3)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
printf("%s dim %u, lengths {%u %u %u}, strides {%u %u %u}\n",
s,
desc.GetDimension(),
desc.GetLength(I0),
desc.GetLength(I1),
desc.GetLength(I2),
desc.GetStride(I0),
desc.GetStride(I1),
desc.GetStride(I2));
}
else if(ndim == 4)
{
constexpr auto I0 = Number<0>{};