mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-20 21:09:08 +00:00
refactor
This commit is contained in:
@@ -83,9 +83,7 @@ struct BlockwiseGenericTensorSliceCopy_v1
|
||||
// divide work
|
||||
constexpr auto data_per_cluster_per_dims = SubLengths{} * DataClusterLengths{};
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto IDim_) {
|
||||
constexpr auto IDim = decltype(IDim_){};
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto IDim) {
|
||||
static_assert(SliceLengths::Get(IDim) % SubLengths::Get(IDim) == 0,
|
||||
"wrong! cannot evenly divide sliced tensor into sub-tensor");
|
||||
|
||||
@@ -95,9 +93,7 @@ struct BlockwiseGenericTensorSliceCopy_v1
|
||||
|
||||
// for now, only support SubLengths == 1 on a merged dimension that contains
|
||||
// multiple original dimensions
|
||||
static_for<0, nDim, 1>{}([&](auto IDim_) {
|
||||
constexpr auto IDim = decltype(IDim_){};
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto IDim) {
|
||||
static_assert(SubLengths::Get(IDim) == 1 ||
|
||||
(!SrcDesc::ContainMultipleOriginalDimensions(IDim) &&
|
||||
!DstDesc::ContainMultipleOriginalDimensions(IDim)),
|
||||
@@ -121,8 +117,7 @@ struct BlockwiseGenericTensorSliceCopy_v1
|
||||
dst_block_data_multi_id_begin + thread_data_multi_id_begin);
|
||||
|
||||
// partial offset on each dimension
|
||||
static_for<0, nDim, 1>{}([&](auto IDim_) {
|
||||
constexpr auto IDim = decltype(IDim_){};
|
||||
static_for<0, nDim, 1>{}([&](auto IDim) {
|
||||
constexpr index_t idim = IDim;
|
||||
|
||||
constexpr auto src_partial_original_dims =
|
||||
@@ -135,8 +130,7 @@ struct BlockwiseGenericTensorSliceCopy_v1
|
||||
extract_array(mThreadSrcOriginalMultiId, src_partial_original_dims));
|
||||
});
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto IDim_) {
|
||||
constexpr auto IDim = decltype(IDim_){};
|
||||
static_for<0, nDim, 1>{}([&](auto IDim) {
|
||||
constexpr index_t idim = IDim;
|
||||
|
||||
constexpr auto dst_partial_original_dims =
|
||||
@@ -208,6 +202,13 @@ struct BlockwiseGenericTensorSliceCopy_v1
|
||||
thread_tensor_desc.GetOffsetFromMultiIndex(clipboard_data_multi_id_begin);
|
||||
#endif
|
||||
|
||||
// The origin of the per-thread window is positioned at the point where the multi-index
|
||||
// of the SrcDesc (might be a merged tensor) is all-zero. This threadwise slice copy
|
||||
// is assuming each thread is copying a normal (not merged) tensor.
|
||||
// The user needs to guarantee this is true.
|
||||
// By setting SubLengths = 1 at the merged dimension, this is always true;
|
||||
// If in the future, you want to enable SubLengths > 1 at the merged dimension,
|
||||
// special care in implementation is needed
|
||||
threadwise_generic_tensor_slice_copy_v1(SrcDesc{},
|
||||
p_src + src_offset + mThreadSrcOffset,
|
||||
make_zero_array<index_t, nDim>(),
|
||||
@@ -259,6 +260,13 @@ struct BlockwiseGenericTensorSliceCopy_v1
|
||||
const index_t dst_offset = DstDesc{}.GetOffsetFromMultiIndex(dst_data_multi_id_begin);
|
||||
#endif
|
||||
|
||||
// The origin of the per-thread window is positioned at the point where the multi-index
|
||||
// of the SrcDesc (might be a merged tensor) is all-zero. This threadwise slice copy
|
||||
// is assuming each thread is copying a normal (not merged) tensor.
|
||||
// The user needs to guarantee this is true.
|
||||
// By setting SubLengths = 1 at the merged dimension, this is always true;
|
||||
// If in the future, you want to enable SubLengths > 1 at the merged dimension,
|
||||
// special care in implementation is needed
|
||||
threadwise_generic_tensor_slice_copy_v1(thread_tensor_desc,
|
||||
p_clipboard + clipboard_offset,
|
||||
make_zero_array<index_t, nDim>(),
|
||||
@@ -292,8 +300,7 @@ struct BlockwiseGenericTensorSliceCopy_v1
|
||||
__device__ void MoveSlicingWindowOnSourceTensor(
|
||||
Number<IDim_>, Number<StepSize>, integral_constant<bool, PositiveDirection> direction)
|
||||
{
|
||||
constexpr auto IDim = Number<IDim_>{};
|
||||
constexpr index_t idim = IDim;
|
||||
constexpr auto IDim = Number<IDim_>{};
|
||||
|
||||
static_if<SrcDesc::ContainMultipleOriginalDimensions(IDim)>{}([&](auto) {
|
||||
// logic for a merged dimension, also works for non-merged dimension, but its logic may
|
||||
@@ -316,22 +323,21 @@ struct BlockwiseGenericTensorSliceCopy_v1
|
||||
old_src_partial_original_multi_id, StepSize, direction);
|
||||
|
||||
// update "mThreadSrcOriginalMultiId"
|
||||
static_for<0, decltype(src_partial_original_dims)::GetSize(), 1>{}([&](auto I_) {
|
||||
constexpr auto I = decltype(I_){};
|
||||
constexpr index_t idim_original = src_partial_original_dims.Get(I);
|
||||
static_for<0, decltype(src_partial_original_dims)::GetSize(), 1>{}([&](auto I) {
|
||||
constexpr auto IDimOriginal = src_partial_original_dims[I];
|
||||
|
||||
mThreadSrcOriginalMultiId(idim_original) = new_src_partial_original_multi_id[I];
|
||||
mThreadSrcOriginalMultiId(IDimOriginal) = new_src_partial_original_multi_id[I];
|
||||
});
|
||||
|
||||
// calculate new partial offset on this merged dimension
|
||||
const index_t old_src_partial_offset = mThreadSrcPartialOffsets[idim];
|
||||
const index_t old_src_partial_offset = mThreadSrcPartialOffsets[IDim];
|
||||
|
||||
const index_t new_src_partial_offset =
|
||||
src_partial_original_desc.GetOffsetFromMultiIndex(
|
||||
new_src_partial_original_multi_id);
|
||||
|
||||
// update "mThreadSrcPartialOffsets"
|
||||
mThreadSrcPartialOffsets(idim) = new_src_partial_offset;
|
||||
mThreadSrcPartialOffsets(IDim) = new_src_partial_offset;
|
||||
|
||||
// update "mThreadSrcOffset", do "+" before "-" to avoid underflow
|
||||
mThreadSrcOffset = (mThreadSrcOffset + new_src_partial_offset) - old_src_partial_offset;
|
||||
@@ -346,20 +352,20 @@ struct BlockwiseGenericTensorSliceCopy_v1
|
||||
// of the boundary of the tensor being sliced. Otherwise, there might be hazard like
|
||||
// unsigned integer underflow. There is NO runtime sanity check to prevent the hazard
|
||||
|
||||
constexpr index_t idim_original = SrcDesc::GetContainedOriginalDimensions(IDim).Front();
|
||||
constexpr auto IDimOriginal = SrcDesc::GetContainedOriginalDimensions(IDim).Front();
|
||||
|
||||
static_if<PositiveDirection>{}([&](auto fwd) {
|
||||
mThreadSrcOffset += StepSize * fwd(SrcDesc{}).GetStride(IDim);
|
||||
|
||||
mThreadSrcOriginalMultiId(idim_original) += StepSize;
|
||||
mThreadSrcOriginalMultiId(IDimOriginal) += StepSize;
|
||||
|
||||
mThreadSrcPartialOffsets(idim) += StepSize * fwd(SrcDesc{}).GetStride(IDim);
|
||||
mThreadSrcPartialOffsets(IDim) += StepSize * fwd(SrcDesc{}).GetStride(IDim);
|
||||
}).Else([&](auto fwd) {
|
||||
mThreadSrcOffset -= StepSize * fwd(SrcDesc{}).GetStride(IDim);
|
||||
|
||||
mThreadSrcOriginalMultiId(idim_original) -= StepSize;
|
||||
mThreadSrcOriginalMultiId(IDimOriginal) -= StepSize;
|
||||
|
||||
mThreadSrcPartialOffsets(idim) -= StepSize * fwd(SrcDesc{}).GetStride(IDim);
|
||||
mThreadSrcPartialOffsets(IDim) -= StepSize * fwd(SrcDesc{}).GetStride(IDim);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user