mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-07 08:15:04 +00:00
Add length/stride getters for HostTensor
This commit is contained in:
@@ -155,7 +155,12 @@ struct HostTensorDescriptor
|
||||
return space;
|
||||
}
|
||||
|
||||
std::size_t get_length(std::size_t dim) const { return mLens[dim]; }
|
||||
|
||||
const std::vector<std::size_t>& get_lengths() const { return mLens; }
|
||||
|
||||
std::size_t get_stride(std::size_t dim) const { return mStrides[dim]; }
|
||||
|
||||
const std::vector<std::size_t>& get_strides() const { return mStrides; }
|
||||
|
||||
template <typename... Is>
|
||||
@@ -325,8 +330,12 @@ struct HostTensor
|
||||
{
|
||||
}
|
||||
|
||||
std::size_t get_legnth(std::size_t dim) const { return mDesc.get_length(dim); }
|
||||
|
||||
decltype(auto) get_lengths() const { return mDesc.get_lengths(); }
|
||||
|
||||
std::size_t get_stride(std::size_t dim) const { return mDesc.get_stride(dim); }
|
||||
|
||||
decltype(auto) get_strides() const { return mDesc.get_strides(); }
|
||||
|
||||
std::size_t get_num_of_dimension() const { return mDesc.get_num_of_dimension(); }
|
||||
|
||||
Reference in New Issue
Block a user