tweak on amd

This commit is contained in:
Chao Liu
2019-08-08 12:14:06 -05:00
parent a9b2b1dcd7
commit 4908fe3fdc
7 changed files with 117 additions and 85 deletions

View File

@@ -402,6 +402,19 @@ struct BlockwiseGenericTensorSliceCopy_v1
});
});
}
template <class T, bool PositiveDirection>
__device__ void
MoveSrcSlicingWindow(T step_sizes,
integral_constant<bool, PositiveDirection> positive_direction)
{
static_for<0, nDim, 1>{}([&](auto idim) {
if(step_sizes[idim] != 0)
{
MoveSlicingWindowOnSourceTensor(idim, step_sizes[idim], positive_direction);
}
});
}
};
template <index_t BlockSize,
@@ -502,21 +515,6 @@ struct BlockwiseGenericTensorSliceCopy_v2
private:
using RegisterBufferDesc = decltype(make_ConstantTensorDescriptor_packed(SubLengths{}));
#if 0
using ThreadwiseLoad =
ThreadwiseGenericTensorSliceCopy_v2<SrcDesc,
RegisterBufferDesc,
SrcCoordinate,
NormalTensorCoordinate<RegisterBufferDesc>,
SubLengths>;
using ThreadwiseStore =
ThreadwiseGenericTensorSliceCopy_v2<RegisterBufferDesc,
DstDesc,
NormalTensorCoordinate<RegisterBufferDesc>,
DstCoordinate,
SubLengths>;
#else
using ThreadwiseLoad =
ThreadwiseGenericTensorSliceCopy_v2r1<SrcDesc,
RegisterBufferDesc,
@@ -542,7 +540,7 @@ struct BlockwiseGenericTensorSliceCopy_v2
DstVectorAccessDim,
1,
DstDataPerAccess>;
#endif
ThreadwiseLoad mThreadwiseLoad;
ThreadwiseStore mThreadwiseStore;
};

View File

@@ -594,7 +594,6 @@ struct ThreadwiseGenericTensorSliceCopy_v2
DstCoordinate mDstSliceOrigin;
};
#if 1
// This threadwise copy allow vector access of src and dst.
// It allows the dimensions of vector access to be different on src and dst.
// It also allows the vector size to be different on src and dst.
@@ -623,6 +622,49 @@ struct ThreadwiseGenericTensorSliceCopy_v2r1
DstCoordinate dst_slice_origin)
: mSrcSliceOrigin(src_slice_origin), mDstSliceOrigin(dst_slice_origin)
{
static_assert(nDim == SrcDesc::GetNumOfDimension() &&
nDim == DstDesc::GetNumOfDimension() && nDim == SliceLengths::GetSize() &&
nDim == SrcDimAccessOrder::GetSize() &&
nDim == DstDimAccessOrder::GetSize(),
"wrong! # of dimensions not the same");
static_assert(is_valid_sequence_map<SrcDimAccessOrder>::value &&
is_valid_sequence_map<DstDimAccessOrder>::value,
"wrong! map is not valid");
static_assert(SliceLengths{}[SrcVectorAccessDim] % SrcDataPerAccess == 0 &&
SliceLengths{}[DstVectorAccessDim] % DstDataPerAccess == 0,
"wrong! cannot evenly divide");
// check vectorized memory access
constexpr auto src_vector_access_dim = Number<SrcVectorAccessDim>{};
constexpr auto dst_vector_access_dim = Number<DstVectorAccessDim>{};
static_if<!SrcDesc::ContainMultipleOriginalDimensions(src_vector_access_dim)>{}(
[&](auto fwd) {
static_assert(
(fwd(SrcDesc{}).GetStride(src_vector_access_dim) == 1 || SrcDataPerAccess == 1),
"wrong! vectorized access is allowed only if stride == 1");
})
.Else([&](auto fwd) {
static_assert(
(fwd(SrcDesc{}).GetLastOriginalDimensionStride(src_vector_access_dim) == 1 ||
SrcDataPerAccess == 1),
"wrong! vectorized access is allowed only if stride == 1");
});
static_if<!DstDesc::ContainMultipleOriginalDimensions(dst_vector_access_dim)>{}(
[&](auto fwd) {
static_assert(
(fwd(DstDesc{}).GetStride(dst_vector_access_dim) == 1 || DstDataPerAccess == 1),
"wrong! vectorized access is allowed only if stride == 1");
})
.Else([&](auto fwd) {
static_assert(
(fwd(DstDesc{}).GetLastOriginalDimensionStride(dst_vector_access_dim) == 1 ||
DstDataPerAccess == 1),
"wrong! vectorized access is allowed only if stride == 1");
});
}
__device__ constexpr ThreadwiseGenericTensorSliceCopy_v2r1()
@@ -725,9 +767,6 @@ struct ThreadwiseGenericTensorSliceCopy_v2r1
constexpr index_t buffer_offset = buffer_desc.GetOffsetFromMultiIndex(
src_merged_dim_data_id + src_normal_dim_data_id + scalar_id);
constexpr index_t buffer_offset =
buffer_desc.GetOffsetFromMultiIndex(src_data_begin_id + scalar_id);
p_buffer[buffer_offset] = reinterpret_cast<const TData*>(&vector_data)[i];
});
});
@@ -900,7 +939,6 @@ struct ThreadwiseGenericTensorSliceCopy_v2r1
SrcCoordinate mSrcSliceOrigin;
DstCoordinate mDstSliceOrigin;
};
#endif
} // namespace ck
#endif