mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
enabling vector load on merged dim
This commit is contained in:
@@ -13,8 +13,10 @@
|
||||
namespace ck {
|
||||
|
||||
// slice a (normal or merged) tensor, and copy it into another (normal or merged) tensor
|
||||
// memory layout (ordering of dimensions) can be different between src and dst
|
||||
// For now, only support SubLengths[...] == 1 on a merged dimension
|
||||
// memory layout (ordering of dimensions) can be different between src and dst.
|
||||
// on a merged dimension that constains multiple original dimensions,
|
||||
// its sub-length need to evenly divide the length of the last original dimension
|
||||
// so each thread is effectively reading a normal (not merged) tensor
|
||||
template <index_t BlockSize,
|
||||
class Float,
|
||||
class SrcDesc,
|
||||
@@ -75,7 +77,7 @@ struct BlockwiseGenericTensorSliceCopy_v1
|
||||
|
||||
// thread cluster
|
||||
constexpr auto thread_cluster_desc = make_ConstantTensorDescriptor_packed(
|
||||
DataClusterLengths{}.ReorderGivenNew2Old(ThreadClusterArrangeOrder{}));
|
||||
DataClusterLengths::ReorderGivenNew2Old(ThreadClusterArrangeOrder{}));
|
||||
|
||||
// BlockSize
|
||||
static_assert(BlockSize == thread_cluster_desc.GetElementSize(), "wrong! BlockSize");
|
||||
@@ -91,13 +93,23 @@ struct BlockwiseGenericTensorSliceCopy_v1
|
||||
"wrong! cannot evenly divide sliced tensor into cluster");
|
||||
});
|
||||
|
||||
// for now, only support SubLengths == 1 on a merged dimension that constains
|
||||
// multiple original dimensions
|
||||
// on a merged dimension that constains multiple original dimensions,
|
||||
// its sub-length need to evenly divide the length of the last original dimension,
|
||||
// so each thread is effectively reading a normal (not merged) tensor
|
||||
static_for<0, nDim, 1>{}([&](auto IDim) {
|
||||
static_assert(SubLengths::Get(IDim) == 1 ||
|
||||
(!SrcDesc::ContainMultipleOriginalDimensions(IDim) &&
|
||||
!DstDesc::ContainMultipleOriginalDimensions(IDim)),
|
||||
"wrong! only surpport Sub-Length == 1 on a merged dimension");
|
||||
constexpr auto sub_length = SubLengths::Get(IDim);
|
||||
|
||||
constexpr auto idim_original_src = SrcDesc::GetContainedOriginalDimensions(IDim).Back();
|
||||
static_assert(SrcDesc::GetOriginalTensorDescriptor().GetLength(idim_original_src) %
|
||||
sub_length ==
|
||||
0,
|
||||
"wrong!");
|
||||
|
||||
constexpr auto idim_original_dst = DstDesc::GetContainedOriginalDimensions(IDim).Back();
|
||||
static_assert(DstDesc::GetOriginalTensorDescriptor().GetLength(idim_original_dst) %
|
||||
sub_length ==
|
||||
0,
|
||||
"wrong!");
|
||||
});
|
||||
|
||||
// calculate mThreadSrcOffset, mThreadDstOffset
|
||||
@@ -118,28 +130,24 @@ struct BlockwiseGenericTensorSliceCopy_v1
|
||||
|
||||
// partial offset on each dimension
|
||||
static_for<0, nDim, 1>{}([&](auto IDim) {
|
||||
constexpr index_t idim = IDim;
|
||||
|
||||
constexpr auto src_partial_original_dims =
|
||||
SrcDesc::GetContainedOriginalDimensions(IDim);
|
||||
|
||||
constexpr auto src_partial_original_desc =
|
||||
SrcDesc::GetOriginalTensorDescriptor().Extract(src_partial_original_dims);
|
||||
|
||||
mThreadSrcPartialOffsets(idim) = src_partial_original_desc.GetOffsetFromMultiIndex(
|
||||
mThreadSrcPartialOffsets(IDim) = src_partial_original_desc.GetOffsetFromMultiIndex(
|
||||
extract_array(mThreadSrcOriginalMultiId, src_partial_original_dims));
|
||||
});
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto IDim) {
|
||||
constexpr index_t idim = IDim;
|
||||
|
||||
constexpr auto dst_partial_original_dims =
|
||||
DstDesc::GetContainedOriginalDimensions(IDim);
|
||||
|
||||
constexpr auto dst_partial_original_desc =
|
||||
DstDesc::GetOriginalTensorDescriptor().Extract(dst_partial_original_dims);
|
||||
|
||||
mThreadDstPartialOffsets(idim) = dst_partial_original_desc.GetOffsetFromMultiIndex(
|
||||
mThreadDstPartialOffsets(IDim) = dst_partial_original_desc.GetOffsetFromMultiIndex(
|
||||
extract_array(mThreadDstOriginalMultiId, dst_partial_original_dims));
|
||||
});
|
||||
|
||||
@@ -173,10 +181,8 @@ struct BlockwiseGenericTensorSliceCopy_v1
|
||||
constexpr auto thread_tensor_desc =
|
||||
make_ConstantTensorDescriptor_packed(thread_sub_tensor_lengths * repeat_lengths);
|
||||
|
||||
static_ford<decltype(repeat_lengths)>{}([&](auto repeat_multi_id_) {
|
||||
#if CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_BLOCKWISE_GENERIC_SLICE_COPY_V1
|
||||
constexpr auto repeat_multi_id = decltype(repeat_multi_id_){};
|
||||
|
||||
static_ford<decltype(repeat_lengths)>{}([&](auto repeat_multi_id) {
|
||||
constexpr auto src_thread_data_multi_id_begin =
|
||||
repeat_multi_id * data_per_cluster_per_dims;
|
||||
|
||||
@@ -189,14 +195,13 @@ struct BlockwiseGenericTensorSliceCopy_v1
|
||||
constexpr index_t clipboard_offset =
|
||||
thread_tensor_desc.GetOffsetFromMultiIndex(clipboard_data_multi_id_begin);
|
||||
#else
|
||||
constexpr auto repeat_multi_id = sequence2array(decltype(repeat_multi_id_){});
|
||||
|
||||
ford<decltype(repeat_lengths)>{}([&](auto repeat_multi_id) {
|
||||
const auto src_thread_data_multi_id_begin = repeat_multi_id * data_per_cluster_per_dims;
|
||||
|
||||
const auto clipboard_data_multi_id_begin = repeat_multi_id * thread_sub_tensor_lengths;
|
||||
|
||||
const index_t src_offset =
|
||||
SrcDesc{}.GetOffsetFromMultiIndex(src_thread_data_multi_id_begin);
|
||||
SrcDesc::GetOffsetFromMultiIndex(src_thread_data_multi_id_begin);
|
||||
|
||||
const index_t clipboard_offset =
|
||||
thread_tensor_desc.GetOffsetFromMultiIndex(clipboard_data_multi_id_begin);
|
||||
@@ -233,10 +238,8 @@ struct BlockwiseGenericTensorSliceCopy_v1
|
||||
constexpr auto thread_tensor_desc =
|
||||
make_ConstantTensorDescriptor_packed(thread_sub_tensor_lengths * repeat_lengths);
|
||||
|
||||
static_ford<decltype(repeat_lengths)>{}([&](auto repeat_multi_id_) {
|
||||
#if CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_BLOCKWISE_GENERIC_SLICE_COPY_V1
|
||||
constexpr auto repeat_multi_id = decltype(repeat_multi_id_){};
|
||||
|
||||
static_ford<decltype(repeat_lengths)>{}([&](auto repeat_multi_id) {
|
||||
constexpr auto clipboard_data_multi_id_begin =
|
||||
repeat_multi_id * thread_sub_tensor_lengths;
|
||||
|
||||
@@ -246,10 +249,9 @@ struct BlockwiseGenericTensorSliceCopy_v1
|
||||
thread_tensor_desc.GetOffsetFromMultiIndex(clipboard_data_multi_id_begin);
|
||||
|
||||
constexpr index_t dst_offset =
|
||||
DstDesc{}.GetOffsetFromMultiIndex(dst_data_multi_id_begin);
|
||||
DstDesc::GetOffsetFromMultiIndex(dst_data_multi_id_begin);
|
||||
#else
|
||||
constexpr auto repeat_multi_id = sequence2array(decltype(repeat_multi_id_){});
|
||||
|
||||
ford<decltype(repeat_lengths)>{}([&](auto repeat_multi_id) {
|
||||
const auto clipboard_data_multi_id_begin = repeat_multi_id * thread_sub_tensor_lengths;
|
||||
|
||||
const auto dst_data_multi_id_begin = repeat_multi_id * data_per_cluster_per_dims;
|
||||
@@ -257,7 +259,7 @@ struct BlockwiseGenericTensorSliceCopy_v1
|
||||
const index_t clipboard_offset =
|
||||
thread_tensor_desc.GetOffsetFromMultiIndex(clipboard_data_multi_id_begin);
|
||||
|
||||
const index_t dst_offset = DstDesc{}.GetOffsetFromMultiIndex(dst_data_multi_id_begin);
|
||||
const index_t dst_offset = DstDesc::GetOffsetFromMultiIndex(dst_data_multi_id_begin);
|
||||
#endif
|
||||
|
||||
// By position the origin of the per-thread window at the point, where multi-index
|
||||
|
||||
Reference in New Issue
Block a user