reworked ThreadwiseGenericTensorSliceCopy_v1
@@ -170,6 +170,8 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
             InBlockCopyThreadClusterArrangeOrder,
             InBlockCopySrcAccessOrder,
             InBlockCopyDstAccessOrder,
+            2,
+            3,
             InBlockCopySrcDataPerRead_B,
             InBlockCopyDstDataPerWrite_N2>(
             {0, 0, b_block_data_on_global, 0}, {0, 0, 0, 0});
@@ -213,6 +215,8 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
             WeiBlockCopyThreadClusterArrangeOrder,
             WeiBlockCopySrcAccessOrder,
             WeiBlockCopyDstAccessOrder,
+            0,
+            1,
             WeiBlockCopySrcDataPerRead_E,
             WeiBlockCopyDstDataPerWrite_K>(
             {0, k_block_data_on_global}, {0, 0});
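Note: the two integer arguments inserted above (`2, 3` for the input copy, `0, 1` for the weight copy) fill the new SrcVectorAccessDim / DstVectorAccessDim template parameters introduced further down in this diff; they select the slice dimension along which consecutive scalars are grouped into one vector access. A minimal compilable sketch of what that means for the input copy, using the Sequence values that appear in the tuning block later in this diff (illustrative, not CK code):

    // dims are [E, N1, B, N2]; DstVectorAccessDim = 3 means the 4 scalars
    // along N2 are written as a single vector
    constexpr int sub_lengths[4]     = {1, 1, 1, 4}; // InBlockCopySubLengths_E_N1_B_N2
    constexpr int dst_vector_dim     = 3;            // N2
    constexpr int dst_data_per_write = 4;            // InBlockCopyDstDataPerWrite_N2
    static_assert(sub_lengths[dst_vector_dim] % dst_data_per_write == 0,
                  "slice length along the vector dimension must be divisible by the access width");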
@@ -434,7 +438,7 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
             out_k_n1_b_n2_global_merged_desc.GetOffsetFromMultiIndex(
                 k_thread_data_on_global, 0, b_thread_data_on_global, 0);

-#if 1
+#if 0
         threadwise_generic_tensor_slice_copy_v1(
             out_n0_n1_n2_k0_k1_k2_h_w_thread_desc,
             p_out_thread,
@@ -445,9 +449,20 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
             out_n0_n1_n2_k0_k1_k2_h_w_thread_desc.GetLengths(),
             arithmetic_sequence_gen<0, 8, 1>::type{},
             Number<1>{});
-#else
+#elif 1
+        ThreadwiseGenericTensorSliceCopy_v1<
+            decltype(out_n0_n1_n2_k0_k1_k2_h_w_thread_desc),
+            decltype(out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc),
+            decltype(out_n0_n1_n2_k0_k1_k2_h_w_thread_desc.GetLengths()),
+            arithmetic_sequence_gen<0, 8, 1>::type,
+            arithmetic_sequence_gen<0, 8, 1>::type,
+            0,
+            0,
+            1,
+            1>({0, 0, 0, 0, 0, 0, 0, 0}, {0, 0, 0, 0, 0, 0, 0, 0})
+            .Run(p_out_thread, p_out_thread_on_global);
+#elif 0
         ThreadwiseGenericTensorSliceCopy_v2<
             Float,
             decltype(out_n0_n1_n2_k0_k1_k2_h_w_thread_desc),
             decltype(out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc),
             NormalTensorCoordinate<decltype(out_n0_n1_n2_k0_k1_k2_h_w_thread_desc)>,
@@ -19,7 +19,8 @@ namespace ck {
 // to simplify index calculations. To satisfy this assumption, the user needs to make sure
 // that, on a merged dimension that contains multiple original dimensions, the length of
 // the last original dimension is evenly divisible by its sub-lengths. Also, the
-// repeat-length on the merged dimension needs to be 1.
+// repeat-length on the merged dimension needs to be 1. These sanity checks are performed
+// in the constructor of BlockwiseGenericTensorSliceCopy_v1.
 template <index_t BlockSize,
           class Float,
           class SrcDesc,
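Note: a minimal compilable illustration of the constraint this comment describes (the numbers are illustrative, not taken from a real descriptor). For a merged dimension built from original dimensions of lengths {2, 4, 8}, a sub-length of 4 is legal because the last original length, 8, is evenly divisible by 4, and the repeat-length must stay 1:

    constexpr int last_original_length = 8; // last original dimension inside the merged one
    constexpr int sub_length           = 4; // per-thread sub-length on the merged dimension
    constexpr int repeat_length        = 1; // repeats on a merged dimension must be 1

    static_assert(last_original_length % sub_length == 0,
                  "last original length must be evenly divisible by the sub-length");
    static_assert(repeat_length == 1, "repeat-length on a merged dimension must be 1");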
@@ -28,10 +29,12 @@ template <index_t BlockSize,
           class SubLengths,
           class ThreadClusterLengths,
           class ThreadClusterArrangeOrder,
-          class SrcAccessOrder,
-          class DstAccessOrder,
-          index_t SrcDataPerRead,
-          index_t DstDataPerWrite>
+          class SrcDimAccessOrder,
+          class DstDimAccessOrder,
+          index_t SrcVectorAccessDim,
+          index_t DstVectorAccessDim,
+          index_t SrcDataPerAccess,
+          index_t DstDataPerAccess>
 struct BlockwiseGenericTensorSliceCopy_v1
 {
     static constexpr index_t nDim = SrcDesc::GetNumOfDimension();
@@ -60,23 +63,22 @@ struct BlockwiseGenericTensorSliceCopy_v1
     Array<index_t, nOriginalDimSrc> mThreadSrcOriginalMultiId;
     Array<index_t, nOriginalDimDst> mThreadDstOriginalMultiId;

-    __device__
-    BlockwiseGenericTensorSliceCopy_v1(Array<index_t, nDim> src_block_data_multi_id_begin,
-                                       Array<index_t, nDim> dst_block_data_multi_id_begin)
+    __device__ BlockwiseGenericTensorSliceCopy_v1(Array<index_t, nDim> src_block_data_id_begin,
+                                                  Array<index_t, nDim> dst_block_data_id_begin)
     {
         // check NDim consistency
-        static_assert(nDim == SrcDesc::GetNumOfDimension() &&
-                          nDim == DstDesc::GetNumOfDimension() && nDim == SliceLengths::GetSize() &&
-                          nDim == SubLengths::GetSize() &&
-                          nDim == ThreadClusterLengths::GetSize() &&
-                          nDim == ThreadClusterArrangeOrder::GetSize() &&
-                          nDim == SrcAccessOrder::GetSize() && nDim == DstAccessOrder::GetSize(),
-                      "wrong");
+        static_assert(
+            nDim == SrcDesc::GetNumOfDimension() && nDim == DstDesc::GetNumOfDimension() &&
+                nDim == SliceLengths::GetSize() && nDim == SubLengths::GetSize() &&
+                nDim == ThreadClusterLengths::GetSize() &&
+                nDim == ThreadClusterArrangeOrder::GetSize() &&
+                nDim == SrcDimAccessOrder::GetSize() && nDim == DstDimAccessOrder::GetSize(),
+            "wrong");

         // check thread arrange order and read/write access order are valid
         static_assert(is_valid_sequence_map<ThreadClusterArrangeOrder>::value &&
-                          is_valid_sequence_map<SrcAccessOrder>::value &&
-                          is_valid_sequence_map<DstAccessOrder>::value,
+                          is_valid_sequence_map<SrcDimAccessOrder>::value &&
+                          is_valid_sequence_map<DstDimAccessOrder>::value,
                       "wrong!");

         // thread cluster
@@ -142,20 +144,20 @@ struct BlockwiseGenericTensorSliceCopy_v1
         });

         // calculate mThreadSrcOffset, mThreadDstOffset
-        const auto thread_cluster_multi_id =
+        const auto thread_cluster_id =
             thread_cluster_desc.GetMultiIndexFrom1dIndex(get_thread_local_1d_id());

-        const auto data_cluster_multi_id =
-            reorder_array_given_old2new(thread_cluster_multi_id, ThreadClusterArrangeOrder{});
+        const auto data_cluster_id =
+            reorder_array_given_old2new(thread_cluster_id, ThreadClusterArrangeOrder{});

-        const auto thread_data_multi_id_begin = data_cluster_multi_id * SubLengths{};
+        const auto thread_data_id_begin = data_cluster_id * SubLengths{};

         // original multi-id
         mThreadSrcOriginalMultiId = SrcDesc::GetOriginalMultiIndexFromMultiIndex(
-            src_block_data_multi_id_begin + thread_data_multi_id_begin);
+            src_block_data_id_begin + thread_data_id_begin);

         mThreadDstOriginalMultiId = DstDesc::GetOriginalMultiIndexFromMultiIndex(
-            dst_block_data_multi_id_begin + thread_data_multi_id_begin);
+            dst_block_data_id_begin + thread_data_id_begin);

         // partial offset on each dimension
         static_for<0, nDim, 1>{}([&](auto IDim) {
@@ -188,14 +190,16 @@ struct BlockwiseGenericTensorSliceCopy_v1
             mThreadDstPartialOffsets, math::plus<index_t>{}, static_cast<index_t>(0));
     }

-    __device__ static constexpr index_t GetRegisterBufferSize()
+    __device__ static constexpr auto GetRegisterBufferDescriptor()
     {
         constexpr auto repeat_lengths = SliceLengths{} / (SubLengths{} * ThreadClusterLengths{});

-        constexpr auto thread_tensor_desc =
-            make_ConstantTensorDescriptor_packed(SubLengths{} * repeat_lengths);
+        return make_ConstantTensorDescriptor_packed(SubLengths{} * repeat_lengths);
+    }

-        return thread_tensor_desc.GetElementSpace();
+    __device__ static constexpr index_t GetRegisterBufferSize()
+    {
+        return GetRegisterBufferDescriptor().GetElementSpace();
     }

     __device__ void RunLoadRegisterBuffer(const Float* __restrict__ p_src,
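Note: this refactor only splits the buffer-size computation so the register-buffer descriptor is reusable elsewhere; the arithmetic is unchanged. A worked instance, using the weight-copy numbers that appear in the tuning block later in this diff (SliceLengths = <8, 128>, SubLengths = <2, 2>, ThreadClusterLengths = <4, 64>):

    // repeat_lengths = SliceLengths / (SubLengths * ThreadClusterLengths), per dimension
    constexpr int repeat_e = 8 / (2 * 4);    // = 1
    constexpr int repeat_k = 128 / (2 * 64); // = 1
    // buffer element count = product of (SubLengths * repeat_lengths)
    static_assert((2 * repeat_e) * (2 * repeat_k) == 4, "each thread buffers 4 weight values");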
@@ -208,50 +212,62 @@ struct BlockwiseGenericTensorSliceCopy_v1

         constexpr auto repeat_lengths = SliceLengths{} / (SubLengths{} * ThreadClusterLengths{});

-        constexpr auto thread_tensor_desc =
-            make_ConstantTensorDescriptor_packed(thread_sub_tensor_lengths * repeat_lengths);
+        constexpr auto thread_buffer_desc = GetRegisterBufferDescriptor();

 #if CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_BLOCKWISE_GENERIC_SLICE_COPY_V1
-        static_ford<decltype(repeat_lengths)>{}([&](auto repeat_multi_id) {
-            constexpr auto src_thread_data_multi_id_begin =
-                repeat_multi_id * data_per_cluster_per_dims;
+        static_ford<decltype(repeat_lengths)>{}([&](auto repeat_id) {
+            constexpr auto src_thread_data_id_begin = repeat_id * data_per_cluster_per_dims;

-            constexpr auto buffer_data_multi_id_begin = repeat_multi_id * thread_sub_tensor_lengths;
+            constexpr auto buffer_data_id_begin = repeat_id * thread_sub_tensor_lengths;

             constexpr index_t src_offset =
-                SrcDesc::GetOffsetFromMultiIndex(src_thread_data_multi_id_begin);
+                SrcDesc::GetOffsetFromMultiIndex(src_thread_data_id_begin);

             constexpr index_t buffer_offset =
-                thread_tensor_desc.GetOffsetFromMultiIndex(buffer_data_multi_id_begin);
+                thread_buffer_desc.GetOffsetFromMultiIndex(buffer_data_id_begin);
 #else
-        ford<decltype(repeat_lengths)>{}([&](auto repeat_multi_id) {
-            const auto src_thread_data_multi_id_begin = repeat_multi_id * data_per_cluster_per_dims;
+        ford<decltype(repeat_lengths)>{}([&](auto repeat_id) {
+            const auto src_thread_data_id_begin = repeat_id * data_per_cluster_per_dims;

-            const auto buffer_data_multi_id_begin = repeat_multi_id * thread_sub_tensor_lengths;
+            const auto buffer_data_id_begin = repeat_id * thread_sub_tensor_lengths;

-            const index_t src_offset =
-                SrcDesc::GetOffsetFromMultiIndex(src_thread_data_multi_id_begin);
+            const index_t src_offset = SrcDesc::GetOffsetFromMultiIndex(src_thread_data_id_begin);

             const index_t buffer_offset =
-                thread_tensor_desc.GetOffsetFromMultiIndex(buffer_data_multi_id_begin);
+                thread_buffer_desc.GetOffsetFromMultiIndex(buffer_data_id_begin);
 #endif

-            // The origin of the per-thread window is positioned at the point where the
-            // multi-index of the SrcDesc (which might be a merged tensor) is all-zero.
-            // This threadwise slice copy assumes each thread copies a normal (not merged)
-            // tensor. The user needs to guarantee this is true. By setting SubLengths = 1
-            // at the merged dimension, this is always true; if, in the future, you want to
-            // enable SubLengths > 1 at the merged dimension, special care in implementation
-            // is needed.
+            // The origin of the per-thread window is positioned at the point where the
+            // multi-index of the SrcDesc (which might be a merged tensor) is all-zero.
+            // This threadwise slice copy assumes each thread copies a normal (not merged)
+            // tensor. To satisfy this assumption, the user needs to make sure that, on a
+            // merged dimension that contains multiple original dimensions, the length of
+            // the last original dimension is evenly divisible by its sub-lengths. Also, the
+            // repeat-length on the merged dimension needs to be 1. These sanity checks are
+            // performed in the constructor of BlockwiseGenericTensorSliceCopy_v1.
+#if 0 // debug
             threadwise_generic_tensor_slice_copy_v1(SrcDesc{},
                                                     p_src + src_offset + mThreadSrcOffset,
                                                     make_zero_array<index_t, nDim>(),
-                                                    thread_tensor_desc,
+                                                    thread_buffer_desc,
                                                     p_buffer + buffer_offset,
                                                     make_zero_array<index_t, nDim>(),
                                                     thread_sub_tensor_lengths,
-                                                    SrcAccessOrder{},
-                                                    Number<SrcDataPerRead>{});
+                                                    SrcDimAccessOrder{},
+                                                    Number<SrcDataPerAccess>{});
+#else
+            ThreadwiseGenericTensorSliceCopy_v1<SrcDesc,
+                                                decltype(thread_buffer_desc),
+                                                SubLengths,
+                                                SrcDimAccessOrder,
+                                                typename arithmetic_sequence_gen<0, nDim, 1>::type,
+                                                SrcVectorAccessDim,
+                                                0,
+                                                SrcDataPerAccess,
+                                                1>(make_zero_array<index_t, nDim>(),
+                                                   make_zero_array<index_t, nDim>())
+                .Run(p_src + src_offset + mThreadSrcOffset, p_buffer + buffer_offset);
+#endif
         });
     }
@@ -265,48 +281,60 @@ struct BlockwiseGenericTensorSliceCopy_v1

         constexpr auto repeat_lengths = SliceLengths{} / (SubLengths{} * ThreadClusterLengths{});

-        constexpr auto thread_tensor_desc =
-            make_ConstantTensorDescriptor_packed(thread_sub_tensor_lengths * repeat_lengths);
+        constexpr auto thread_buffer_desc = GetRegisterBufferDescriptor();

 #if CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_BLOCKWISE_GENERIC_SLICE_COPY_V1
-        static_ford<decltype(repeat_lengths)>{}([&](auto repeat_multi_id) {
-            constexpr auto buffer_data_multi_id_begin = repeat_multi_id * thread_sub_tensor_lengths;
+        static_ford<decltype(repeat_lengths)>{}([&](auto repeat_id) {
+            constexpr auto buffer_data_id_begin = repeat_id * thread_sub_tensor_lengths;

-            constexpr auto dst_data_multi_id_begin = repeat_multi_id * data_per_cluster_per_dims;
+            constexpr auto dst_data_id_begin = repeat_id * data_per_cluster_per_dims;

             constexpr index_t buffer_offset =
-                thread_tensor_desc.GetOffsetFromMultiIndex(buffer_data_multi_id_begin);
+                thread_buffer_desc.GetOffsetFromMultiIndex(buffer_data_id_begin);

-            constexpr index_t dst_offset =
-                DstDesc::GetOffsetFromMultiIndex(dst_data_multi_id_begin);
+            constexpr index_t dst_offset = DstDesc::GetOffsetFromMultiIndex(dst_data_id_begin);
 #else
-        ford<decltype(repeat_lengths)>{}([&](auto repeat_multi_id) {
-            const auto buffer_data_multi_id_begin = repeat_multi_id * thread_sub_tensor_lengths;
+        ford<decltype(repeat_lengths)>{}([&](auto repeat_id) {
+            const auto buffer_data_id_begin = repeat_id * thread_sub_tensor_lengths;

-            const auto dst_data_multi_id_begin = repeat_multi_id * data_per_cluster_per_dims;
+            const auto dst_data_id_begin = repeat_id * data_per_cluster_per_dims;

             const index_t buffer_offset =
-                thread_tensor_desc.GetOffsetFromMultiIndex(buffer_data_multi_id_begin);
+                thread_buffer_desc.GetOffsetFromMultiIndex(buffer_data_id_begin);

-            const index_t dst_offset = DstDesc::GetOffsetFromMultiIndex(dst_data_multi_id_begin);
+            const index_t dst_offset = DstDesc::GetOffsetFromMultiIndex(dst_data_id_begin);
 #endif

             // The origin of the per-thread window is positioned at the point where the
             // multi-index of the SrcDesc (which might be a merged tensor) is all-zero.
             // This threadwise slice copy assumes each thread copies a normal (not merged)
             // tensor. The user needs to guarantee this is true. By setting SubLengths = 1
             // at the merged dimension, this is always true; if, in the future, you want to
             // enable SubLengths > 1 at the merged dimension, special care in implementation
             // is needed.
-            threadwise_generic_tensor_slice_copy_v1(thread_tensor_desc,
+#if 0 // debug
+            threadwise_generic_tensor_slice_copy_v1(thread_buffer_desc,
                                                     p_buffer + buffer_offset,
                                                     make_zero_array<index_t, nDim>(),
                                                     DstDesc{},
                                                     p_dst + dst_offset + mThreadDstOffset,
                                                     make_zero_array<index_t, nDim>(),
                                                     thread_sub_tensor_lengths,
-                                                    DstAccessOrder{},
-                                                    Number<DstDataPerWrite>{});
+                                                    DstDimAccessOrder{},
+                                                    Number<DstDataPerAccess>{});
+#else
+            ThreadwiseGenericTensorSliceCopy_v1<decltype(thread_buffer_desc),
+                                                DstDesc,
+                                                SubLengths,
+                                                typename arithmetic_sequence_gen<0, nDim, 1>::type,
+                                                DstDimAccessOrder,
+                                                0,
+                                                DstVectorAccessDim,
+                                                1,
+                                                DstDataPerAccess>(make_zero_array<index_t, nDim>(),
+                                                                  make_zero_array<index_t, nDim>())
+                .Run(p_buffer + buffer_offset, p_dst + dst_offset + mThreadDstOffset);
+#endif
         });
     }
@@ -346,26 +374,25 @@ struct BlockwiseGenericTensorSliceCopy_v1
             SrcDesc::GetOriginalTensorDescriptor().Extract(src_partial_original_dims);

         // calculate new partial original multi-id
-        auto old_src_partial_original_multi_id =
+        auto old_src_partial_original_id =
             extract_array(mThreadSrcOriginalMultiId, src_partial_original_dims);

-        auto new_src_partial_original_multi_id =
+        auto new_src_partial_original_id =
             src_partial_original_desc.UpdateMultiIndexGivenStepSizeOf1dIndex(
-                old_src_partial_original_multi_id, StepSize, direction);
+                old_src_partial_original_id, StepSize, direction);

         // update "mThreadSrcOriginalMultiId"
         static_for<0, decltype(src_partial_original_dims)::GetSize(), 1>{}([&](auto I) {
             constexpr auto IDimOriginal = src_partial_original_dims[I];

-            mThreadSrcOriginalMultiId(IDimOriginal) = new_src_partial_original_multi_id[I];
+            mThreadSrcOriginalMultiId(IDimOriginal) = new_src_partial_original_id[I];
         });

         // calculate new partial offset on this merged dimension
         const index_t old_src_partial_offset = mThreadSrcPartialOffsets[IDim];

         const index_t new_src_partial_offset =
-            src_partial_original_desc.GetOffsetFromMultiIndex(
-                new_src_partial_original_multi_id);
+            src_partial_original_desc.GetOffsetFromMultiIndex(new_src_partial_original_id);

         // update "mThreadSrcPartialOffsets"
         mThreadSrcPartialOffsets(IDim) = new_src_partial_offset;
@@ -434,19 +461,19 @@ struct BlockwiseGenericTensorSliceCopy_v2
         static_assert(BlockSize == thread_cluster_desc.GetElementSize(),
                       "wrong! BlockSize not consistent with ThreadClusterLengths");

-        const auto thread_cluster_multi_id =
+        const auto thread_cluster_id =
             thread_cluster_desc.GetMultiIndexFrom1dIndex(get_thread_local_1d_id());

-        const auto data_cluster_multi_id =
-            reorder_array_given_old2new(thread_cluster_multi_id, ThreadClusterArrangeOrder{});
+        const auto data_cluster_id =
+            reorder_array_given_old2new(thread_cluster_id, ThreadClusterArrangeOrder{});

-        const auto thread_data_multi_id_begin = data_cluster_multi_id * SubLengths{};
+        const auto thread_data_id_begin = data_cluster_id * SubLengths{};

-        mThreadwiseLoad.SetSrcSliceOrigin(src_block_slice_origin + thread_data_multi_id_begin);
+        mThreadwiseLoad.SetSrcSliceOrigin(src_block_slice_origin + thread_data_id_begin);
         mThreadwiseLoad.SetDstSliceOrigin(make_zero_array<index_t, nDim>());

         mThreadwiseStore.SetSrcSliceOrigin(make_zero_array<index_t, nDim>());
-        mThreadwiseStore.SetDstSliceOrigin(dst_block_slice_origin + thread_data_multi_id_begin);
+        mThreadwiseStore.SetDstSliceOrigin(dst_block_slice_origin + thread_data_id_begin);
     }

     __device__ static constexpr index_t GetRegisterBufferSize()
@@ -106,7 +106,7 @@ __device__ void threadwise_generic_tensor_slice_copy_v1(
 #endif
 }

-#if 0
+#if 1
 template <class SrcDesc,
           class DstDesc,
           class SliceLengths,
@@ -118,7 +118,7 @@ template <class SrcDesc,
           index_t DstDataPerAccess>
 struct ThreadwiseGenericTensorSliceCopy_v1
 {
-    static constexpr index_t nDim = SliceLengths::GetNumOfDimension();
+    static constexpr index_t nDim = SliceLengths::GetSize();

     __device__ constexpr ThreadwiseGenericTensorSliceCopy_v1(Array<index_t, nDim> src_slice_origin,
                                                              Array<index_t, nDim> dst_slice_origin)
@@ -130,39 +130,43 @@ struct ThreadwiseGenericTensorSliceCopy_v1
                           nDim == DstDimAccessOrder::GetSize(),
                       "wrong! # of dimensions not the same");

-        static_assert(is_valid_sequence_map<SrcDimAccessOrder>::{} &&
-                          is_valid_sequence_map<DstDimAccessOrder>::{},
+        static_assert(is_valid_sequence_map<SrcDimAccessOrder>::value &&
+                          is_valid_sequence_map<DstDimAccessOrder>::value,
                       "wrong! map is not valid");

-        static_assert(SliceLengths{}[SrcVectorDim] % SrcDataPerAccess == 0 &&
-                          SliceLengths{DstVectorDim} % DstDataPerAccess == 0,
+        static_assert(SliceLengths{}[SrcVectorAccessDim] % SrcDataPerAccess == 0 &&
+                          SliceLengths{}[DstVectorAccessDim] % DstDataPerAccess == 0,
                       "wrong! cannot evenly divide");

         // check vectorized memory access
-        constexpr auto src_vector_access_dim = Number<SrcVectorAccessDIm>{};
-        constexpr auto dst_vector_access_dim = Number<DstVectorAccessDIm>{};
+        constexpr auto src_vector_access_dim = Number<SrcVectorAccessDim>{};
+        constexpr auto dst_vector_access_dim = Number<DstVectorAccessDim>{};

-        static_if<!SrcDesc::ContainMultipleOriginalDimensions(
-            src_vector_access_dim)>{}([&](auto fwd) {
-            static_assert(
-                (fwd(SrcDesc{}).GetStrides()[SrcVectorAccessDim] == 1 || SrcDataPerAccess == 1),
-                "wrong! vectorized access is allowed only if stride == 1");
-        }).Else{}([&](auto fwd) {
-            static_assert((SrcDesc::GetLastOriginalDimensionStride(src_vector_access_dim) == 1 ||
-                           SrcDataPerAccess == 1),
-                          "wrong! vectorized access is allowed only if stride == 1");
-        });
+        static_if<!SrcDesc::ContainMultipleOriginalDimensions(src_vector_access_dim)>{}(
+            [&](auto fwd) {
+                static_assert(
+                    (fwd(SrcDesc{}).GetStrides()[SrcVectorAccessDim] == 1 || SrcDataPerAccess == 1),
+                    "wrong! vectorized access is allowed only if stride == 1");
+            })
+            .Else([&](auto fwd) {
+                static_assert(
+                    (fwd(SrcDesc{}).GetLastOriginalDimensionStride(src_vector_access_dim) == 1 ||
+                     SrcDataPerAccess == 1),
+                    "wrong! vectorized access is allowed only if stride == 1");
+            });

-        static_if<!DstDesc::ContainMultipleOriginalDimensions(
-            dst_vector_access_dim)>{}([&](auto fwd) {
-            static_assert(
-                (fwd(DstDesc{}).GetStrides()[DstVectorAccessDim] == 1 || DstDataPerAccess == 1),
-                "wrong! vectorized access is allowed only if stride == 1");
-        }).Else{}([&](auto fwd) {
-            static_assert((DstDesc::GetLastOriginalDimensionStride(dst_vector_access_dim) == 1 ||
-                           DstDataPerAccess == 1),
-                          "wrong! vectorized access is allowed only if stride == 1");
-        });
+        static_if<!DstDesc::ContainMultipleOriginalDimensions(dst_vector_access_dim)>{}(
+            [&](auto fwd) {
+                static_assert(
+                    (fwd(DstDesc{}).GetStrides()[DstVectorAccessDim] == 1 || DstDataPerAccess == 1),
+                    "wrong! vectorized access is allowed only if stride == 1");
+            })
+            .Else([&](auto fwd) {
+                static_assert(
+                    (fwd(DstDesc{}).GetLastOriginalDimensionStride(dst_vector_access_dim) == 1 ||
+                     DstDataPerAccess == 1),
+                    "wrong! vectorized access is allowed only if stride == 1");
+            });
     }

     __device__ constexpr ThreadwiseGenericTensorSliceCopy_v1()
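Note: the stride == 1 requirement these static_asserts enforce exists because a vectorized access reinterprets SrcDataPerAccess (or DstDataPerAccess) adjacent scalars as a single vector, which is only correct when the elements are contiguous along the vector access dimension. A minimal sketch of what the generated access looks like, assuming a clang/HIP ext_vector_type such as the one vector_type wraps (illustrative, not the CK implementation):

    typedef float float4_t __attribute__((ext_vector_type(4)));

    __device__ float4_t load4(const float* p)
    {
        // reads p[0..3] as one 16-byte transaction; if the logical elements were
        // strided (stride > 1), this would load the wrong scalars, hence the
        // "vectorized access is allowed only if stride == 1" asserts above
        return *reinterpret_cast<const float4_t*>(p);
    }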
@@ -186,23 +190,87 @@ struct ThreadwiseGenericTensorSliceCopy_v1
     {
         constexpr auto buffer_desc = make_ConstantTensorDescriptor_packed(SliceLengths{});

-        TData p_buffer[buffer_desc.GetElementSpace()];
+        TData p_buffer_[buffer_desc.GetElementSpace()];
+        TData* p_buffer = p_buffer_;

         // copy data from src into buffer
-        constexpr auto src_vector_access_dim = Number<SrcVectorAccessDIm>{};
+        {
+            using vector_t = typename vector_type<TData, SrcDataPerAccess>::MemoryType;

-        constexpr auto src_access_lengths = SliceLengths::Modify(
-            src_vector_access_dim, SliceLengths::Get(src_vector_access_dim) / SrcDataPerAccess);
+            constexpr auto src_vector_access_dim = Number<SrcVectorAccessDim>{};
+            constexpr auto src_data_per_access   = Number<SrcDataPerAccess>{};

-        constexpr auto src_access_lengths_in_src_access_order =
-            src_access_lengths.ReorderGivenNew2Old(SrcDimAccessOrder{});
+            constexpr auto src_access_lengths = SliceLengths::Modify(
+                src_vector_access_dim,
+                SliceLengths::Get(src_vector_access_dim) / src_data_per_access);

-        static_ford<decltype(src_access_lengths_in_src_access_order)>{}([&](auto src_access_id) {});
+            static_ford<decltype(src_access_lengths), SrcDimAccessOrder>{}([&](auto src_access_id) {
+                constexpr auto src_data_id = src_access_id.Modify(
+                    src_vector_access_dim,
+                    src_access_id[src_vector_access_dim] * src_data_per_access);
+
+                const index_t src_offset =
+                    SrcDesc::GetOffsetFromMultiIndex(mSrcSliceOrigin + src_data_id);
+
+                // load vector from src
+                const vector_t vector_data = *reinterpret_cast<const vector_t*>(&p_src[src_offset]);
+
+                // unpack vector into buffer
+                static_for<0, SrcDataPerAccess, 1>{}([&](auto i) {
+                    constexpr auto scalar_id =
+                        typename uniform_sequence_gen<nDim, 0>::type{}.Modify(src_vector_access_dim,
+                                                                              i);
+
+                    constexpr index_t buffer_offset =
+                        buffer_desc.GetOffsetFromMultiIndex(src_data_id + scalar_id);
+
+                    p_buffer[buffer_offset] = reinterpret_cast<const TData*>(&vector_data)[i];
+                });
+            });
+        }
+
+        // copy data from buffer to dst
+        {
+            using vector_t = typename vector_type<TData, DstDataPerAccess>::MemoryType;
+
+            constexpr auto dst_vector_access_dim = Number<DstVectorAccessDim>{};
+            constexpr auto dst_data_per_access   = Number<DstDataPerAccess>{};
+
+            constexpr auto dst_access_lengths = SliceLengths::Modify(
+                dst_vector_access_dim,
+                SliceLengths::Get(dst_vector_access_dim) / dst_data_per_access);
+
+            static_ford<decltype(dst_access_lengths), DstDimAccessOrder>{}([&](auto dst_access_id) {
+                constexpr auto dst_data_id = dst_access_id.Modify(
+                    dst_vector_access_dim,
+                    dst_access_id[dst_vector_access_dim] * dst_data_per_access);
+
+                vector_t vector_data;
+
+                // pack vector from buffer
+                static_for<0, DstDataPerAccess, 1>{}([&](auto i) {
+                    constexpr auto scalar_id =
+                        typename uniform_sequence_gen<nDim, 0>::type{}.Modify(dst_vector_access_dim,
+                                                                              i);
+
+                    constexpr index_t buffer_offset =
+                        buffer_desc.GetOffsetFromMultiIndex(dst_data_id + scalar_id);
+
+                    reinterpret_cast<TData*>(&vector_data)[i] = p_buffer[buffer_offset];
+                });
+
+                const index_t dst_offset =
+                    DstDesc::GetOffsetFromMultiIndex(mDstSliceOrigin + dst_data_id);
+
+                // store vector into dst
+                *reinterpret_cast<vector_t*>(&p_dst[dst_offset]) = vector_data;
+            });
+        }
     }

     private:
-    Array<index_t, TData> mSrcSliceOrigin;
-    Array<index_t, TData> mDstSliceOrigin;
+    Array<index_t, nDim> mSrcSliceOrigin;
+    Array<index_t, nDim> mDstSliceOrigin;
 };
 #endif
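Note: a usage sketch of the reworked interface, assuming two packed 4x8 descriptors and that Run deduces the data type from its pointer arguments, as in the calls above (illustrative; the template-argument order follows the parameter list shown in this diff: SrcDesc, DstDesc, SliceLengths, SrcDimAccessOrder, DstDimAccessOrder, SrcVectorAccessDim, DstVectorAccessDim, SrcDataPerAccess, DstDataPerAccess):

    using SliceLengths = Sequence<4, 8>;
    using Desc         = decltype(make_ConstantTensorDescriptor_packed(SliceLengths{}));

    __device__ void copy_tile(const float* p_src, float* p_dst)
    {
        ThreadwiseGenericTensorSliceCopy_v1<Desc,
                                            Desc,
                                            SliceLengths,
                                            Sequence<0, 1>, // SrcDimAccessOrder
                                            Sequence<0, 1>, // DstDimAccessOrder
                                            1,              // SrcVectorAccessDim
                                            1,              // DstVectorAccessDim
                                            4,              // SrcDataPerAccess: 4-wide loads
                                            4>              // DstDataPerAccess: 4-wide stores
            ({0, 0}, {0, 0})
                .Run(p_src, p_dst);
    }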
@@ -23,14 +23,16 @@ struct static_for_impl<Sequence<Is...>>
 template <index_t NBegin, index_t NEnd, index_t Increment>
 struct static_for
 {
+    __host__ __device__ constexpr static_for()
+    {
+        static_assert(NBegin <= NEnd, "wrongs! should have NBegin <= NEnd");
+        static_assert((NEnd - NBegin) % Increment == 0,
+                      "Wrong! should satisfy (NEnd - NBegin) % Increment == 0");
+    }
+
     template <class F>
     __host__ __device__ constexpr void operator()(F f) const
     {
-        static_assert(NBegin <= NEnd, "wrongs! should have NBegin <= NEnd");
-
-        static_assert((NEnd - NBegin) % Increment == 0,
-                      "Wrong! should satisfy (NEnd - NBegin) % Increment == 0");
-
         static_for_impl<typename arithmetic_sequence_gen<NBegin, NEnd, Increment>::type>{}(f);
     }
 };
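Note: with the bounds checks moved into the constructor, merely constructing static_for<NBegin, NEnd, Increment>{} now triggers the static_asserts, even before operator() is invoked. Usage is unchanged; a sketch of the idiom as the copy kernels above use it (the functor receives each index as a compile-time Number<I>, so the loop is fully unrolled; p_acc / p_buf are hypothetical register arrays):

    __device__ void accumulate4(float* p_acc, const float* p_buf)
    {
        static_for<0, 4, 1>{}([&](auto I) {
            constexpr index_t i = decltype(I)::value; // compile-time index
            p_acc[i] += p_buf[i];                     // unrolled at compile time
        });
    }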
@@ -94,6 +94,41 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,

     constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
     constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
+#elif 1
+    // each thread holds 64 data
+    constexpr index_t BlockSize = 256;
+
+    constexpr index_t BPerBlock = 16;
+    constexpr index_t KPerBlock = 128;
+    constexpr index_t EPerBlock = 8;
+
+    constexpr index_t GemmMPerThreadSubC = 4;
+    constexpr index_t GemmNPerThreadSubC = 4;
+    constexpr index_t GemmMLevel0Cluster = 4;
+    constexpr index_t GemmNLevel0Cluster = 4;
+    constexpr index_t GemmMLevel1Cluster = 4;
+    constexpr index_t GemmNLevel1Cluster = 4;
+    constexpr index_t GemmKPerThreadLoop = 1;
+    constexpr index_t GemmDataPerReadA   = 4;
+    constexpr index_t GemmDataPerReadB   = 4;
+
+    using InBlockCopySubLengths_E_N1_B_N2      = Sequence<1, 1, 1, 4>;
+    using InBlockCopyClusterLengths_E_N1_B_N2  = Sequence<8, 2, 16, 1>;
+    using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
+    using InBlockCopySrcAccessOrder            = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
+    using InBlockCopyDstAccessOrder            = Sequence<0, 1, 2, 3>; // [E, N1, B, N2]
+
+    constexpr index_t InBlockCopySrcDataPerRead_B   = 1;
+    constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4;
+
+    using WeiBlockCopySubLengths_E_K            = Sequence<2, 2>;
+    using WeiBlockCopyClusterLengths_E_K        = Sequence<4, 64>;
+    using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
+    using WeiBlockCopySrcAccessOrder            = Sequence<1, 0>; // [K, E]
+    using WeiBlockCopyDstAccessOrder            = Sequence<0, 1>; // [E, K]
+
+    constexpr index_t WeiBlockCopySrcDataPerRead_E  = 2;
+    constexpr index_t WeiBlockCopyDstDataPerWrite_K = 2;
 #elif 0
     // each thread holds 32 data
     constexpr index_t BlockSize = 256;
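Note: the "64 data" in the new #elif 1 block can be checked against the constants above. A sketch of the arithmetic (N1 = 2 and N2 = 4 are read off the cluster and sub-length Sequences; the per-thread tile works out to 8x8):

    // M per thread = KPerBlock / (GemmMLevel0Cluster * GemmMLevel1Cluster) = 128 / 16 = 8
    constexpr int MPerThread = 128 / (4 * 4);
    // N per thread = (BPerBlock * N1 * N2) / (GemmNLevel0Cluster * GemmNLevel1Cluster) = 128 / 16 = 8
    constexpr int NPerThread = (16 * 2 * 4) / (4 * 4);
    static_assert(MPerThread * NPerThread == 64, "each thread holds an 8x8 accumulator tile");

    // both block-copy thread clusters must also cover the whole 256-thread block
    static_assert(8 * 2 * 16 * 1 == 256, "InBlockCopyClusterLengths_E_N1_B_N2 covers BlockSize");
    static_assert(4 * 64 == 256, "WeiBlockCopyClusterLengths_E_K covers BlockSize");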
@@ -9,14 +9,14 @@
 #include "conv_common.hpp"
 #include "host_conv.hpp"
 #include "device_convolution_direct_v2_nchw_kcyx_nkhw.hpp"
-#include "device_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hpp"
-#include "device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw.hpp"
-#include "device_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hpp"
-#include "device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hpp"
+//#include "device_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hpp"
+//#include "device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw.hpp"
+//#include "device_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hpp"
+//#include "device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hpp"
 #include "device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp"
-#include "device_convolution_implicit_gemm_v4r2_nchw_kcyx_nkhw.hpp"
-#include "device_convolution_implicit_gemm_v4r3_nchw_kcyx_nkhw.hpp"
-#include "device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
+//#include "device_convolution_implicit_gemm_v4r2_nchw_kcyx_nkhw.hpp"
+//#include "device_convolution_implicit_gemm_v4r3_nchw_kcyx_nkhw.hpp"
+//#include "device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"

 struct GeneratorTensor_1
 {