implicit gemm v1r2: adding support for nchw

This commit is contained in:
Chao Liu
2019-04-18 11:49:09 -05:00
parent 17f3d2d4bc
commit 19f17df47a
16 changed files with 1624 additions and 220 deletions

View File

@@ -108,11 +108,11 @@ template <class Lengths, class Strides>
struct ConstantTensorDescriptor
{
using Type = ConstantTensorDescriptor<Lengths, Strides>;
static constexpr index_t nDim = Lengths::nDim;
static constexpr index_t nDim = Lengths::GetSize();
__host__ __device__ constexpr ConstantTensorDescriptor()
{
static_assert(Lengths::nDim == Strides::nDim, "nDim not consistent");
static_assert(Lengths::GetSize() == Strides::GetSize(), "nDim not consistent");
}
__host__ __device__ static constexpr index_t GetDimension() { return nDim; }
@@ -157,12 +157,10 @@ struct ConstantTensorDescriptor
return align.Get() * ((element_space_unaligned + align.Get() - 1) / align.Get());
}
template <class... Is>
__host__ __device__ static index_t Get1dIndex(Is... is)
template <index_t NSize>
__host__ __device__ static index_t Get1dIndex(Array<index_t, NSize> multi_id)
{
static_assert(sizeof...(Is) == nDim, "number of multi-index is wrong");
const auto multi_id = Array<index_t, nDim>(is...);
static_assert(NSize == nDim, "wrong! Dimension not consistent");
index_t id = 0;
@@ -178,6 +176,16 @@ struct ConstantTensorDescriptor
return id;
}
template <class... Is>
__host__ __device__ static index_t Get1dIndex(Is... is)
{
static_assert(sizeof...(Is) == nDim, "number of multi-index is wrong");
const auto multi_id = Array<index_t, nDim>(is...);
return Get1dIndex(multi_id);
}
__host__ __device__ static Array<index_t, nDim> GetMultiIndex(index_t id)
{
Array<index_t, nDim> multi_id;