use constant tensor descriptor

This commit is contained in:
Chao Liu
2018-11-04 04:08:51 -06:00
parent 9657baec32
commit 1b648f2f42
8 changed files with 913 additions and 91 deletions

View File

@@ -3,16 +3,13 @@
#include "tensor.hpp"
TensorDescriptor::TensorDescriptor(DataType_t t, std::initializer_list<std::size_t> lens)
: mLens(lens), mDataType(t)
TensorDescriptor::TensorDescriptor(std::initializer_list<std::size_t> lens) : mLens(lens)
{
this->CalculateStrides();
}
TensorDescriptor::TensorDescriptor(DataType_t t,
std::vector<std::size_t> lens,
std::vector<std::size_t> strides)
: mLens(lens), mStrides(strides), mDataType(t)
TensorDescriptor::TensorDescriptor(std::vector<std::size_t> lens, std::vector<std::size_t> strides)
: mLens(lens), mStrides(strides)
{
}
@@ -28,8 +25,6 @@ void TensorDescriptor::CalculateStrides()
mLens.rbegin(), mLens.rend() - 1, mStrides.rbegin() + 1, std::multiplies<std::size_t>());
}
DataType_t TensorDescriptor::GetDataType() const { return mDataType; }
std::size_t TensorDescriptor::GetDimension() const { return mLens.size(); }
std::size_t TensorDescriptor::GetElementSize() const