mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 17:55:48 +00:00
adding tensor_view
This commit is contained in:
@@ -5,6 +5,7 @@
|
||||
#include "ConstantTensorDescriptor.hpp"
|
||||
#include "ConstantMergedTensorDescriptor.hpp"
|
||||
#include "tensor_coordinate.hpp"
|
||||
#include "tensor_view.hpp"
|
||||
#include "threadwise_generic_tensor_slice_copy.hpp"
|
||||
|
||||
#ifndef CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_BLOCKWISE_GENERIC_SLICE_COPY_V1
|
||||
@@ -442,12 +443,13 @@ struct BlockwiseGenericTensorSliceCopy_v2
|
||||
__device__ constexpr BlockwiseGenericTensorSliceCopy_v2(SrcCoordinate src_block_slice_origin,
|
||||
DstCoordinate dst_block_slice_origin)
|
||||
{
|
||||
static_assert(nDim == SrcDesc::GetNumOfDimension() &&
|
||||
nDim == DstDesc::GetNumOfDimension() && nDim == SliceLengths::GetSize() &&
|
||||
nDim == SubLengths::GetSize() &&
|
||||
nDim == ThreadClusterLengths::GetSize() &&
|
||||
nDim == ThreadClusterArrangeOrder::GetSize(),
|
||||
"wrong! nDim not consistent");
|
||||
static_assert(
|
||||
nDim == SrcDesc::GetNumOfDimension() && nDim == DstDesc::GetNumOfDimension() &&
|
||||
nDim == SliceLengths::GetSize() && nDim == SubLengths::GetSize() &&
|
||||
nDim == ThreadClusterLengths::GetSize() &&
|
||||
nDim == ThreadClusterArrangeOrder::GetSize() &&
|
||||
nDim == SrcDimAccessOrder::GetSize() && nDim == DstDimAccessOrder::GetSize(),
|
||||
"wrong! nDim not consistent");
|
||||
|
||||
static_assert(is_same<SliceLengths, decltype(SubLengths{} * ThreadClusterLengths{})>{},
|
||||
"wrong! threads should be mapped to cover entire slicing window");
|
||||
@@ -542,6 +544,129 @@ struct BlockwiseGenericTensorSliceCopy_v2
|
||||
ThreadwiseStore mThreadwiseStore;
|
||||
};
|
||||
|
||||
template <index_t BlockSize,
|
||||
class SrcTensor,
|
||||
class DstTensor,
|
||||
class SliceLengths,
|
||||
class SubLengths,
|
||||
class ThreadClusterLengths,
|
||||
class ThreadClusterArrangeOrder,
|
||||
class SrcDimAccessOrder,
|
||||
class DstDimAccessOrder,
|
||||
index_t SrcVectorAccessDim,
|
||||
index_t DstVectorAccessDim,
|
||||
index_t SrcDataPerAccess,
|
||||
index_t DstDataPerAccess>
|
||||
struct BlockwiseGenericTensorSliceCopy_v3
|
||||
{
|
||||
static constexpr index_t nDim = SrcTensor::GetNumOfDimension();
|
||||
using data_type = remove_cv_t<typename SrcTensor::data_type>;
|
||||
|
||||
using SrcCoordinate = typename SrcTensor::coordinate_type;
|
||||
using DstCoordinate = typename DstTensor::coordinate_type;
|
||||
|
||||
__device__ constexpr BlockwiseGenericTensorSliceCopy_v3(SrcTensor src_block,
|
||||
SrcCoordinate src_block_slice_origin,
|
||||
DstTensor dst_block,
|
||||
DstCoordinate dst_block_slice_origin)
|
||||
: mThreadBuffer{make_TensorView(ThreadBufferDesc{}, mpBuffer)}
|
||||
{
|
||||
static_assert(
|
||||
nDim == SrcTensor::GetNumOfDimension() && nDim == DstTensor::GetNumOfDimension() &&
|
||||
nDim == SliceLengths::GetSize() && nDim == SubLengths::GetSize() &&
|
||||
nDim == ThreadClusterLengths::GetSize() &&
|
||||
nDim == ThreadClusterArrangeOrder::GetSize() &&
|
||||
nDim == SrcDimAccessOrder::GetSize() && nDim == DstDimAccessOrder::GetSize(),
|
||||
"wrong! nDim not consistent");
|
||||
|
||||
static_assert(is_same<SliceLengths, decltype(SubLengths{} * ThreadClusterLengths{})>{},
|
||||
"wrong! threads should be mapped to cover entire slicing window");
|
||||
|
||||
static_assert(is_same<remove_cv_t<typename SrcTensor::data_type>,
|
||||
remove_cv_t<typename DstTensor::data_type>>{},
|
||||
"wrong! type conversion not supported yet");
|
||||
|
||||
constexpr auto thread_cluster_desc = make_ConstantTensorDescriptor_packed(
|
||||
ThreadClusterLengths::ReorderGivenNew2Old(ThreadClusterArrangeOrder{}));
|
||||
|
||||
static_assert(BlockSize == thread_cluster_desc.GetElementSize(),
|
||||
"wrong! BlockSize not consistent with ThreadClusterLengths");
|
||||
|
||||
const auto thread_cluster_id =
|
||||
thread_cluster_desc.GetMultiIndexFrom1dIndex(get_thread_local_1d_id());
|
||||
|
||||
const auto data_cluster_id =
|
||||
reorder_array_given_old2new(thread_cluster_id, ThreadClusterArrangeOrder{});
|
||||
|
||||
const auto thread_data_id_begin = data_cluster_id * SubLengths{};
|
||||
|
||||
mThreadwiseLoad = ThreadwiseLoad(src_block,
|
||||
src_block_slice_origin + thread_data_id_begin,
|
||||
mThreadBuffer,
|
||||
make_zero_array<index_t, nDim>());
|
||||
|
||||
mThreadwiseStore = ThreadwiseStore(mThreadBuffer,
|
||||
make_zero_array<index_t, nDim>(),
|
||||
dst_block,
|
||||
dst_block_slice_origin + thread_data_id_begin);
|
||||
}
|
||||
|
||||
__device__ void RunLoadRegisterBuffer() { mThreadwiseLoad.Run(); }
|
||||
|
||||
__device__ void RunStoreRegisterBuffer() const { mThreadwiseStore.Run(); }
|
||||
|
||||
__device__ void Run()
|
||||
{
|
||||
mThreadwiseLoad.Run();
|
||||
mThreadwiseStore.Run();
|
||||
}
|
||||
|
||||
template <class T, bool PositiveDirection>
|
||||
__device__ void
|
||||
MoveSrcSliceWindow(T step_sizes, integral_constant<bool, PositiveDirection> positive_direction)
|
||||
{
|
||||
mThreadwiseLoad.MoveSrcSliceWindow(step_sizes, positive_direction);
|
||||
}
|
||||
|
||||
template <class T, bool PositiveDirection>
|
||||
__device__ void
|
||||
MoveDstSliceWindow(T step_sizes, integral_constant<bool, PositiveDirection> positive_direction)
|
||||
{
|
||||
mThreadwiseStore.MoveDstSliceWindow(step_sizes, positive_direction);
|
||||
}
|
||||
|
||||
private:
|
||||
using ThreadBufferDesc = decltype(make_ConstantTensorDescriptor_packed(SubLengths{}));
|
||||
using ThreadBufferTensor = NormalTensorView<ThreadBufferDesc, data_type>;
|
||||
|
||||
using ThreadwiseLoad = ThreadwiseGenericTensorSliceCopy_v3<SrcTensor,
|
||||
ThreadBufferTensor,
|
||||
SubLengths,
|
||||
SrcDimAccessOrder,
|
||||
SrcDimAccessOrder,
|
||||
SrcVectorAccessDim,
|
||||
SrcVectorAccessDim,
|
||||
SrcDataPerAccess,
|
||||
1>;
|
||||
|
||||
using ThreadwiseStore = ThreadwiseGenericTensorSliceCopy_v3<ThreadBufferTensor,
|
||||
DstTensor,
|
||||
SubLengths,
|
||||
DstDimAccessOrder,
|
||||
DstDimAccessOrder,
|
||||
DstVectorAccessDim,
|
||||
DstVectorAccessDim,
|
||||
1,
|
||||
DstDataPerAccess>;
|
||||
|
||||
data_type mpBuffer[ThreadBufferDesc::GetElementSpace()];
|
||||
|
||||
ThreadBufferTensor mThreadBuffer;
|
||||
|
||||
ThreadwiseLoad mThreadwiseLoad;
|
||||
ThreadwiseStore mThreadwiseStore;
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
|
||||
#endif
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
#include "ConstantTensorDescriptor.hpp"
|
||||
#include "ConstantMergedTensorDescriptor.hpp"
|
||||
#include "tensor_coordinate.hpp"
|
||||
#include "tensor_view.hpp"
|
||||
|
||||
#ifndef CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1R1
|
||||
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1R1 0
|
||||
@@ -773,5 +774,171 @@ struct ThreadwiseGenericTensorSliceCopy_v2r1
|
||||
DstCoordinate mDstSliceOrigin;
|
||||
};
|
||||
|
||||
template <class SrcTensor,
|
||||
class DstTensor,
|
||||
class SliceLengths,
|
||||
class SrcDimAccessOrder,
|
||||
class DstDimAccessOrder,
|
||||
index_t SrcVectorAccessDim,
|
||||
index_t DstVectorAccessDim,
|
||||
index_t SrcDataPerAccess,
|
||||
index_t DstDataPerAccess>
|
||||
struct ThreadwiseGenericTensorSliceCopy_v3
|
||||
{
|
||||
static constexpr index_t nDim = SrcTensor::GetNumOfDimension();
|
||||
using data_type = remove_cv_t<typename SrcTensor::data_type>;
|
||||
|
||||
using SrcCoordinate = typename SrcTensor::coordinate_type;
|
||||
using DstCoordinate = typename DstTensor::coordinate_type;
|
||||
|
||||
__device__ constexpr ThreadwiseGenericTensorSliceCopy_v3(SrcTensor src,
|
||||
SrcCoordinate src_slice_origin,
|
||||
DstTensor dst,
|
||||
DstCoordinate dst_slice_origin)
|
||||
: mSrc{src},
|
||||
mDst{dst},
|
||||
mSrcSlice{src.Slice(src_slice_origin, SliceLengths{})},
|
||||
mDstSlice{dst.Slice(dst_slice_origin, SliceLengths{})}
|
||||
{
|
||||
static_assert(nDim == SrcTensor::GetNumOfDimension() &&
|
||||
nDim == DstTensor::GetNumOfDimension() &&
|
||||
nDim == SliceLengths::GetSize() && nDim == SrcDimAccessOrder::GetSize() &&
|
||||
nDim == DstDimAccessOrder::GetSize(),
|
||||
"wrong! # of dimensions not the same");
|
||||
|
||||
static_assert(is_valid_sequence_map<SrcDimAccessOrder>::value &&
|
||||
is_valid_sequence_map<DstDimAccessOrder>::value,
|
||||
"wrong! map is not valid");
|
||||
|
||||
static_assert(is_same<remove_cv_t<typename SrcTensor::data_type>,
|
||||
remove_cv_t<typename DstTensor::data_type>>{},
|
||||
"wrong! type conversion is not supported yet");
|
||||
|
||||
static_assert(decltype(mSrcSlice)::IsVectorizationAllowed(Number<SrcVectorAccessDim>{},
|
||||
Number<SrcDataPerAccess>{}) &&
|
||||
decltype(mDstSlice)::IsVectorizationAllowed(Number<DstVectorAccessDim>{},
|
||||
Number<DstDataPerAccess>{}),
|
||||
"wrong! vectorized access is not allowed");
|
||||
}
|
||||
|
||||
__device__ constexpr ThreadwiseGenericTensorSliceCopy_v3()
|
||||
: ThreadwiseGenericTensorSliceCopy_v3(
|
||||
SrcTensor{}, SrcCoordinate{}, DstTensor{}, DstCoordinate{})
|
||||
{
|
||||
}
|
||||
|
||||
__device__ void Run() const
|
||||
{
|
||||
// buffer
|
||||
constexpr auto buffer_desc = make_ConstantTensorDescriptor_packed(SrcTensor::GetLengths());
|
||||
data_type p_buffer[buffer_desc.GetElementSpace()];
|
||||
auto buffer = make_TensorView(buffer_desc, p_buffer);
|
||||
|
||||
// copy data from src into buffer
|
||||
{
|
||||
using src_vector_t = typename vector_type<data_type, SrcDataPerAccess>::MemoryType;
|
||||
|
||||
constexpr auto src_vector_access_dim = Number<SrcVectorAccessDim>{};
|
||||
constexpr auto src_data_per_access = Number<SrcDataPerAccess>{};
|
||||
|
||||
auto src_slice_vectorized =
|
||||
mSrcSlice.Vectorize(src_vector_access_dim, src_data_per_access);
|
||||
|
||||
#if 0
|
||||
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
|
||||
{
|
||||
print_ConstantTensorDescriptor("mSrcSlice: ", typename decltype(mSrcSlice)::tensor_desc_type{});
|
||||
print_ConstantTensorDescriptor("src_slice_vector: ", typename decltype(src_slice_vectorized)::tensor_desc_type{});
|
||||
}
|
||||
#endif
|
||||
|
||||
#if 1 // debug
|
||||
ford<decltype(src_slice_vectorized.GetLengths()), SrcDimAccessOrder>{}(
|
||||
[&](auto src_vector_id) {
|
||||
// load vector from src
|
||||
const src_vector_t vector_data = src_slice_vectorized[src_vector_id];
|
||||
|
||||
// unpack vector into buffer
|
||||
auto src_scalar_id = src_vector_id;
|
||||
src_scalar_id(src_vector_access_dim) *= src_data_per_access;
|
||||
|
||||
for(index_t i = 0; i < SrcDataPerAccess; ++i)
|
||||
{
|
||||
auto id = make_zero_array<index_t, nDim>();
|
||||
id(src_vector_access_dim) = i;
|
||||
|
||||
buffer(src_scalar_id + id) =
|
||||
reinterpret_cast<const data_type*>(&vector_data)[i];
|
||||
}
|
||||
});
|
||||
#endif
|
||||
}
|
||||
|
||||
// copy data from buffer into dst
|
||||
{
|
||||
using dst_vector_t = typename vector_type<data_type, DstDataPerAccess>::MemoryType;
|
||||
|
||||
constexpr auto dst_vector_access_dim = Number<DstVectorAccessDim>{};
|
||||
constexpr auto dst_data_per_access = Number<DstDataPerAccess>{};
|
||||
|
||||
auto dst_slice_vectorized =
|
||||
mDstSlice.Vectorize(dst_vector_access_dim, dst_data_per_access);
|
||||
|
||||
#if 0
|
||||
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
|
||||
{
|
||||
print_ConstantTensorDescriptor("mDstSlice: ", typename decltype(mDstSlice)::tensor_desc_type{});
|
||||
print_ConstantTensorDescriptor("dst_slice_vector: ", typename decltype(dst_slice_vectorized)::tensor_desc_type{});
|
||||
}
|
||||
#endif
|
||||
|
||||
#if 1 // debug
|
||||
ford<decltype(dst_slice_vectorized.GetLengths()), DstDimAccessOrder>{}(
|
||||
[&](auto dst_vector_id) {
|
||||
|
||||
dst_vector_t vector_data{};
|
||||
|
||||
// pack vector from buffer
|
||||
auto dst_scalar_id = dst_vector_id;
|
||||
dst_scalar_id(dst_vector_access_dim) *= dst_data_per_access;
|
||||
|
||||
for(index_t i = 0; i < DstDataPerAccess; ++i)
|
||||
{
|
||||
auto id = make_zero_array<index_t, nDim>();
|
||||
id(dst_vector_access_dim) = i;
|
||||
|
||||
reinterpret_cast<data_type*>(&vector_data)[i] = buffer[dst_scalar_id + id];
|
||||
}
|
||||
|
||||
// write vector into dst
|
||||
dst_slice_vectorized(dst_vector_id) = vector_data;
|
||||
});
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
// T can be Sequence or Array
|
||||
template <class T, bool PositiveDirection>
|
||||
__device__ void MoveSrcSliceWindow(T step_sizes, integral_constant<bool, PositiveDirection>)
|
||||
{
|
||||
mSrc.MoveSliceWindow(mSrcSlice, step_sizes, integral_constant<bool, PositiveDirection>{});
|
||||
}
|
||||
|
||||
template <class T, bool PositiveDirection>
|
||||
__device__ void MoveDstSliceWindow(T step_sizes, integral_constant<bool, PositiveDirection>)
|
||||
{
|
||||
mDst.MoveSliceWindow(mDstSlice, step_sizes, integral_constant<bool, PositiveDirection>{});
|
||||
}
|
||||
|
||||
private:
|
||||
using SrcSlice = decltype(SrcTensor{}.Slice(make_zero_array<index_t, nDim>(), SliceLengths{}));
|
||||
using DstSlice = decltype(DstTensor{}.Slice(make_zero_array<index_t, nDim>(), SliceLengths{}));
|
||||
|
||||
SrcTensor mSrc;
|
||||
DstTensor mDst;
|
||||
SrcSlice mSrcSlice;
|
||||
DstSlice mDstSlice;
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user