mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 09:16:52 +00:00
use constant tensor descriptor
This commit is contained in:
@@ -3,16 +3,13 @@
|
||||
|
||||
#include "tensor.hpp"
|
||||
|
||||
TensorDescriptor::TensorDescriptor(DataType_t t, std::initializer_list<std::size_t> lens)
|
||||
: mLens(lens), mDataType(t)
|
||||
TensorDescriptor::TensorDescriptor(std::initializer_list<std::size_t> lens) : mLens(lens)
|
||||
{
|
||||
this->CalculateStrides();
|
||||
}
|
||||
|
||||
TensorDescriptor::TensorDescriptor(DataType_t t,
|
||||
std::vector<std::size_t> lens,
|
||||
std::vector<std::size_t> strides)
|
||||
: mLens(lens), mStrides(strides), mDataType(t)
|
||||
TensorDescriptor::TensorDescriptor(std::vector<std::size_t> lens, std::vector<std::size_t> strides)
|
||||
: mLens(lens), mStrides(strides)
|
||||
{
|
||||
}
|
||||
|
||||
@@ -28,8 +25,6 @@ void TensorDescriptor::CalculateStrides()
|
||||
mLens.rbegin(), mLens.rend() - 1, mStrides.rbegin() + 1, std::multiplies<std::size_t>());
|
||||
}
|
||||
|
||||
DataType_t TensorDescriptor::GetDataType() const { return mDataType; }
|
||||
|
||||
std::size_t TensorDescriptor::GetDimension() const { return mLens.size(); }
|
||||
|
||||
std::size_t TensorDescriptor::GetElementSize() const
|
||||
|
||||
Reference in New Issue
Block a user