mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 09:16:52 +00:00
initial build
This commit is contained in:
@@ -3,8 +3,6 @@
|
||||
|
||||
#include "tensor.hpp"
|
||||
|
||||
TensorDescriptor::TensorDescriptor() {}
|
||||
|
||||
TensorDescriptor::TensorDescriptor(DataType_t t, std::initializer_list<std::size_t> lens)
|
||||
: mLens(lens), mDataType(t)
|
||||
{
|
||||
@@ -22,7 +20,7 @@ void TensorDescriptor::CalculateStrides()
|
||||
{
|
||||
mStrides.clear();
|
||||
mStrides.resize(mLens.size(), 0);
|
||||
if(strides.empty())
|
||||
if(mStrides.empty())
|
||||
return;
|
||||
|
||||
mStrides.back() = 1;
|
||||
@@ -41,6 +39,10 @@ std::size_t TensorDescriptor::GetElementSize() const
|
||||
|
||||
std::size_t TensorDescriptor::GetElementSpace() const
|
||||
{
|
||||
auto ls = mLens | boost::adaptor::transformed([](auto v) { return v - 1; });
|
||||
auto ls = mLens | boost::adaptors::transformed([](std::size_t v) { return v - 1; });
|
||||
return std::inner_product(ls.begin(), ls.end(), mStrides.begin(), std::size_t{0}) + 1;
|
||||
}
|
||||
|
||||
const std::vector<std::size_t>& TensorDescriptor::GetLengths() const { return mLens; }
|
||||
|
||||
const std::vector<std::size_t>& TensorDescriptor::GetStrides() const { return mStrides; }
|
||||
|
||||
Reference in New Issue
Block a user