mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 09:16:52 +00:00
adding ConstantMergedTensorDescriptor, refactering ConstantTensorDescriptor, Sequence
This commit is contained in:
@@ -16,6 +16,8 @@ struct Array
|
||||
{
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr index_t GetSize() const { return NSize; }
|
||||
|
||||
__host__ __device__ const TData& operator[](index_t i) const { return mData[i]; }
|
||||
|
||||
__host__ __device__ TData& operator[](index_t i) { return mData[i]; }
|
||||
@@ -67,6 +69,23 @@ __host__ __device__ auto reorder_array_given_old2new(const Array<TData, NSize>&
|
||||
return new_array;
|
||||
}
|
||||
|
||||
template <class TData, index_t NSize, class ExtractSeq>
|
||||
__host__ __device__ auto extract_array(const Array<TData, NSize>& old_array, ExtractSeq)
|
||||
{
|
||||
Array<TData, ExtractSeq::GetSize()> new_array;
|
||||
|
||||
constexpr index_t new_size = ExtractSeq::GetSize();
|
||||
|
||||
static_assert(new_size <= NSize, "wrong! too many extract");
|
||||
|
||||
static_for<0, new_size, 1>{}([&](auto I) {
|
||||
constexpr index_t i = I.Get();
|
||||
new_array[i] = old_array[ExtractSeq{}.Get(I)];
|
||||
});
|
||||
|
||||
return new_array;
|
||||
}
|
||||
|
||||
template <class TData, index_t NSize>
|
||||
__host__ __device__ constexpr auto operator+(const Array<TData, NSize>& a,
|
||||
const Array<TData, NSize>& b)
|
||||
|
||||
Reference in New Issue
Block a user