mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Added Multi_ABD support into Gemm and GroupedGemmFixedNK (#978)
* added an example grouped_gemm_multi_abd * fixed ci * add setElementwiseOp * changed API * clean code: add multiA into example * fixed v7r2 copy * add transpose * clean * fixed vector_load check * Update example/15_grouped_gemm/grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp Co-authored-by: Bartłomiej Kocot <barkocot@amd.com> * Update example/15_grouped_gemm/grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp Co-authored-by: Bartłomiej Kocot <barkocot@amd.com> * Update example/15_grouped_gemm/grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp Co-authored-by: Bartłomiej Kocot <barkocot@amd.com> * Update include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp Co-authored-by: Bartłomiej Kocot <barkocot@amd.com> * Update include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp Co-authored-by: Bartłomiej Kocot <barkocot@amd.com> * Update include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp Co-authored-by: Bartłomiej Kocot <barkocot@amd.com> * Update include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp Co-authored-by: Bartłomiej Kocot <barkocot@amd.com> * Update include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp Co-authored-by: Bartłomiej Kocot <barkocot@amd.com> * Update include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp Co-authored-by: Bartłomiej Kocot <barkocot@amd.com> * Update include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp Co-authored-by: Bartłomiej Kocot <barkocot@amd.com> * Update include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp Co-authored-by: Bartłomiej Kocot <barkocot@amd.com> * Update include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd_fixed_nk.hpp Co-authored-by: Bartłomiej Kocot <barkocot@amd.com> * Update include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd_fixed_nk.hpp Co-authored-by: Bartłomiej Kocot <barkocot@amd.com> * add reduce * testing * add example_b16_i8 * refactor example * clean * add mpading * disable reduce for kbatch = 1 * seperate reduce device op * add reduce op * add guard for workspace_size * add instances * format * fixed * add client example * add a colmajor * add instances * Update cmake-ck-dev.sh * Update profile_gemm_splitk.cpp * Update gridwise_gemm_xdlops_v2r4r2.hpp * format * Update profile_gemm_splitk.cpp * fixed * fixed * adjust test * adjust precision loss * adjust test * fixed * add bf16_i8 scale bias * fixed scale * fixed scale elementwise_op * revert contraction deviceop changes * fixed * Add AddFastGelu * Revert "Merge branch 'jizhan/gemm_splitk_reduce' into grouped_gemm_multi_abd_fixed_nk_example" This reverts commit3b5d001efd, reversing changes made to943199a991. * add Scales into elementwise * add gemm_multi_abd client example * add client examples * add rcr and crr * add grouped gemm client example * add grouped gemm client example * add instance for rcr crr * format * fixed * fixed cmake * fixed * fixed client_example * format * fixed contraction isSupport * Update include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd_fixed_nk.hpp Co-authored-by: Bartłomiej Kocot <barkocot@amd.com> * Update device_reduce_threadwise.hpp * clean * Fixes * Fix example --------- Co-authored-by: Jing Zhang <jizha@amd.com> Co-authored-by: Bartłomiej Kocot <barkocot@amd.com>
This commit is contained in:
@@ -10,38 +10,9 @@
|
||||
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// Do following things to avoid "alloca" in LLVM-IR, which would cause scratch memory
|
||||
// and sometimes useless instructions:
|
||||
// 1. Don't save a reference to tensor descriptor in class, pass in tensor descriptor as argument
|
||||
// instead
|
||||
// 2. Don't construct a new tensor coordinate everytime when using it, update and reuse the same
|
||||
// tensor coordinate instead
|
||||
// 3. Don't use a pointer to VGPR buffer, use vector instead
|
||||
|
||||
namespace detail {
|
||||
// TODO: How to fix this? It uses an struct instead of lambda because lambda
|
||||
// doesn't have constructor
|
||||
template <index_t VectorDim, index_t ScalarPerVector>
|
||||
struct lambda_scalar_per_access
|
||||
{
|
||||
__host__ __device__ constexpr auto operator()(index_t i) const
|
||||
{
|
||||
return (i == VectorDim) ? ScalarPerVector : 1;
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t VectorDim>
|
||||
struct lambda_scalar_step_in_vector
|
||||
{
|
||||
__host__ __device__ constexpr auto operator()(index_t i) const
|
||||
{
|
||||
return (i == VectorDim) ? 1 : 0;
|
||||
}
|
||||
};
|
||||
} // namespace detail
|
||||
|
||||
// Assume:
|
||||
// 1. src:
|
||||
// 1. SrcDesc is known at compile-time
|
||||
|
||||
@@ -0,0 +1,67 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
#pragma once
|
||||
|
||||
namespace ck {
|
||||
|
||||
// Do following things to avoid "alloca" in LLVM-IR, which would cause scratch memory
|
||||
// and sometimes useless instructions:
|
||||
// 1. Don't save a reference to tensor descriptor in class, pass in tensor descriptor as argument
|
||||
// instead
|
||||
// 2. Don't construct a new tensor coordinate everytime when using it, update and reuse the same
|
||||
// tensor coordinate instead
|
||||
// 3. Don't use a pointer to VGPR buffer, use vector instead
|
||||
|
||||
namespace detail {
|
||||
// TODO: How to fix this? It uses an struct instead of lambda because lambda
|
||||
// doesn't have constructor
|
||||
template <index_t VectorDim, index_t ScalarPerVector>
|
||||
struct lambda_scalar_per_access
|
||||
{
|
||||
__host__ __device__ constexpr auto operator()(index_t i) const
|
||||
{
|
||||
return (i == VectorDim) ? ScalarPerVector : 1;
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t VectorDim>
|
||||
struct lambda_scalar_step_in_vector
|
||||
{
|
||||
__host__ __device__ constexpr auto operator()(index_t i) const
|
||||
{
|
||||
return (i == VectorDim) ? 1 : 0;
|
||||
}
|
||||
};
|
||||
|
||||
// TODO: How to fix this? It uses an struct instead of lambda because lambda
|
||||
// doesn't have constructor
|
||||
template <index_t SrcVectorDim,
|
||||
index_t SrcScalarPerVector,
|
||||
index_t DstVectorDim,
|
||||
index_t DstScalarPerVector>
|
||||
struct lambda_scalar_per_access_for_src_and_dst
|
||||
{
|
||||
__host__ __device__ constexpr auto operator()(index_t i) const
|
||||
{
|
||||
if(i == SrcVectorDim && i == DstVectorDim)
|
||||
{
|
||||
return math::lcm(SrcScalarPerVector, DstScalarPerVector);
|
||||
}
|
||||
else if(i == SrcVectorDim)
|
||||
{
|
||||
return SrcScalarPerVector;
|
||||
}
|
||||
else if(i == DstVectorDim)
|
||||
{
|
||||
return DstScalarPerVector;
|
||||
}
|
||||
else
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
} // namespace ck
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -7,44 +7,13 @@
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
|
||||
#include "ck/tensor/static_tensor.hpp"
|
||||
#include "ck/utility/is_detected.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
namespace detail {
|
||||
// TODO: How to fix this? It uses an struct instead of lambda because lambda
|
||||
// doesn't have constructor
|
||||
template <index_t SrcVectorDim,
|
||||
index_t SrcScalarPerVector,
|
||||
index_t DstVectorDim,
|
||||
index_t DstScalarPerVector>
|
||||
struct lambda_scalar_per_access_for_src_and_dst
|
||||
{
|
||||
__host__ __device__ constexpr auto operator()(index_t i) const
|
||||
{
|
||||
if(i == SrcVectorDim && i == DstVectorDim)
|
||||
{
|
||||
return math::lcm(SrcScalarPerVector, DstScalarPerVector);
|
||||
}
|
||||
else if(i == SrcVectorDim)
|
||||
{
|
||||
return SrcScalarPerVector;
|
||||
}
|
||||
else if(i == DstVectorDim)
|
||||
{
|
||||
return DstScalarPerVector;
|
||||
}
|
||||
else
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
// Assume:
|
||||
// 1. src_desc and dst_desc are not known at compile-time
|
||||
// 2. SrcBuffer and DstBuffer are DynamicBuffer
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -8,9 +8,11 @@
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_description/tensor_space_filling_curve.hpp"
|
||||
#include "ck/utility/is_detected.hpp"
|
||||
#include "ck/tensor/static_tensor.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// Thread-level multi-source, multi-destination tensor slice data movement
|
||||
// Assume:
|
||||
// 1. All sources and destinations are DynamicBuffer
|
||||
@@ -70,16 +72,18 @@ struct ThreadwiseTensorSliceTransfer_v7r2
|
||||
static constexpr auto src_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
|
||||
|
||||
using SrcSpaceFillingCurve = SpaceFillingCurve<SliceLengths,
|
||||
SrcDimAccessOrder,
|
||||
remove_cv_t<decltype(src_scalar_per_access)>>;
|
||||
|
||||
static constexpr auto dst_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
|
||||
|
||||
using SrcSpaceFillingCurve = SpaceFillingCurve<SliceLengths,
|
||||
SrcDimAccessOrder,
|
||||
remove_cv_t<decltype(src_scalar_per_access)>,
|
||||
false>;
|
||||
|
||||
using DstSpaceFillingCurve = SpaceFillingCurve<SliceLengths,
|
||||
DstDimAccessOrder,
|
||||
remove_cv_t<decltype(dst_scalar_per_access)>>;
|
||||
remove_cv_t<decltype(dst_scalar_per_access)>,
|
||||
false>;
|
||||
|
||||
__device__ constexpr ThreadwiseTensorSliceTransfer_v7r2(
|
||||
const SrcDescs& src_descs,
|
||||
@@ -139,9 +143,9 @@ struct ThreadwiseTensorSliceTransfer_v7r2
|
||||
__device__ void RunRead(const SrcDescs& src_descs, const SrcBuffers& src_bufs)
|
||||
{
|
||||
// loop over space-filling curve
|
||||
static_for<0, num_access, 1>{}([&](auto iAccess) {
|
||||
static_for<0, src_num_access, 1>{}([&](auto iAccess) {
|
||||
auto src_vectors = generate_vectors<SrcDatas, SrcScalarPerVector>();
|
||||
auto dst_vectors = generate_vectors<DstDatas, DstScalarPerVector>();
|
||||
auto elm_vectors = generate_vectors<DstDatas, SrcScalarPerVector>();
|
||||
|
||||
// copy data from src_bufs into src_vectors
|
||||
static_for<0, nSrc, 1>{}([&](auto i) {
|
||||
@@ -199,7 +203,7 @@ struct ThreadwiseTensorSliceTransfer_v7r2
|
||||
|
||||
using elem_op_vec_t = typename vector_type<DstData, elem_op_vec_len>::type;
|
||||
|
||||
return dst_vectors(iDst).template AsType<elem_op_vec_t>()(i);
|
||||
return elm_vectors(iDst).template AsType<elem_op_vec_t>()(i);
|
||||
},
|
||||
Number<nDst>{});
|
||||
|
||||
@@ -214,10 +218,10 @@ struct ThreadwiseTensorSliceTransfer_v7r2
|
||||
unpack2(element_op_, dst_data_refs, src_data_refs);
|
||||
});
|
||||
|
||||
dst_vectors_tuple_(iAccess) = dst_vectors;
|
||||
elm_vectors_tuple_(iAccess) = elm_vectors;
|
||||
|
||||
// move coordinate
|
||||
if constexpr(iAccess.value != num_access - 1)
|
||||
if constexpr(iAccess.value != src_num_access - 1)
|
||||
{
|
||||
constexpr auto forward_step = SrcSpaceFillingCurve::GetForwardStep(iAccess);
|
||||
|
||||
@@ -241,15 +245,113 @@ struct ThreadwiseTensorSliceTransfer_v7r2
|
||||
});
|
||||
}
|
||||
|
||||
__device__ void TransposeFromElmToDst()
|
||||
{
|
||||
using DstData = remove_cvref_t<decltype(DstDatas{}[I0])>;
|
||||
|
||||
using SrcThreadScratch =
|
||||
StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
|
||||
DstData,
|
||||
SrcScalarPerVector,
|
||||
decltype(GetSrcThreadScratchDescriptor()),
|
||||
true>;
|
||||
|
||||
using DstThreadScratch =
|
||||
StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
|
||||
DstData,
|
||||
DstScalarPerVector,
|
||||
decltype(GetDstThreadScratchDescriptor()),
|
||||
true>;
|
||||
|
||||
SrcThreadScratch elm_thread_scratch_;
|
||||
DstThreadScratch dst_thread_scratch_;
|
||||
|
||||
elm_thread_scratch_.data_ =
|
||||
bit_cast<decltype(elm_thread_scratch_.data_)>(elm_vectors_tuple_);
|
||||
|
||||
if constexpr(SrcVectorDim != DstVectorDim &&
|
||||
((is_same<half_t, remove_cvref_t<DstData>>::value &&
|
||||
SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0) ||
|
||||
(is_same<int8_t, remove_cvref_t<DstData>>::value &&
|
||||
SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0)))
|
||||
{
|
||||
// each transpose does
|
||||
// DstScalarPerVector # of src vectors in src_thread_scratch_
|
||||
// SrcScalarPerVector # of dst vectors in dst_thread_scratch_
|
||||
constexpr index_t num_src_vector = Number<DstScalarPerVector>{};
|
||||
constexpr index_t num_dst_vector = Number<SrcScalarPerVector>{};
|
||||
|
||||
// Assume SrcVectorDim is not the same as DstVectorDim, so we do transpose
|
||||
// TODO: make this logic generic for all scenario
|
||||
|
||||
constexpr auto src_scalar_step_in_vector = generate_sequence(
|
||||
detail::lambda_scalar_step_in_vector<SrcVectorDim>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto dst_scalar_step_in_vector = generate_sequence(
|
||||
detail::lambda_scalar_step_in_vector<DstVectorDim>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access_for_src_and_dst<SrcVectorDim,
|
||||
SrcScalarPerVector,
|
||||
DstVectorDim,
|
||||
DstScalarPerVector>{},
|
||||
Number<nDim>{});
|
||||
|
||||
constexpr auto access_lengths = SliceLengths{} / scalar_per_access;
|
||||
|
||||
static_ford<decltype(access_lengths)>{}([&](auto access_idx) {
|
||||
constexpr auto data_idx = access_idx * scalar_per_access;
|
||||
|
||||
constexpr auto data_idx_seq = generate_sequence_v2(
|
||||
[&](auto i) { return Number<data_idx[i]>{}; }, Number<nDim>{});
|
||||
|
||||
using src_vector_t = vector_type_maker_t<DstData, SrcScalarPerVector>;
|
||||
using dst_vector_t = vector_type_maker_t<DstData, DstScalarPerVector>;
|
||||
|
||||
// get DstScalarPerVector # of read-only references to src vectors from
|
||||
// src_thread_scratch_
|
||||
const auto src_vector_refs = generate_tie(
|
||||
[&](auto i) -> const src_vector_t& {
|
||||
// i increment corresponds to movement in DstVectorDim
|
||||
return elm_thread_scratch_.GetVectorTypeReference(
|
||||
data_idx_seq + i * dst_scalar_step_in_vector);
|
||||
},
|
||||
Number<num_src_vector>{});
|
||||
|
||||
// get SrcScalarPerVector # of references to dst vectors from dst_thread_scratch_
|
||||
auto dst_vector_refs = generate_tie(
|
||||
[&](auto i) -> dst_vector_t& {
|
||||
// i increment corresponds to movement in SrcVectorDim
|
||||
return dst_thread_scratch_.GetVectorTypeReference(
|
||||
data_idx_seq + i * src_scalar_step_in_vector);
|
||||
},
|
||||
Number<num_dst_vector>{});
|
||||
|
||||
// do data transpose
|
||||
transpose_vectors<DstData, DstScalarPerVector, SrcScalarPerVector>{}(
|
||||
src_vector_refs, dst_vector_refs);
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
static_ford<SliceLengths>{}(
|
||||
[&](auto idx) { dst_thread_scratch_(idx) = elm_thread_scratch_[idx]; });
|
||||
}
|
||||
|
||||
dst_vectors_tuple_ = bit_cast<decltype(dst_vectors_tuple_)>(dst_thread_scratch_.data_);
|
||||
}
|
||||
|
||||
// DstDescs: Tuple<const DstDesc0&, const DstDesc1&, ...>
|
||||
// DstBuffers: Tuple<const DstBuffer0&, const DstBuffer1&, ...>
|
||||
template <typename DstBuffers,
|
||||
enable_if_t<DstDescs::Size() == DstBuffers::Size(), bool> = false>
|
||||
enable_if_t<DstDescs::Size() == 1 && DstBuffers::Size() == 1, bool> = false>
|
||||
__device__ void RunWrite(const DstDescs& dst_descs, DstBuffers dst_bufs)
|
||||
{
|
||||
TransposeFromElmToDst();
|
||||
|
||||
// loop over space-filling curve
|
||||
static_for<0, num_access, 1>{}([&](auto iAccess) {
|
||||
auto dst_vectors = dst_vectors_tuple_[iAccess];
|
||||
static_for<0, dst_num_access, 1>{}([&](auto iAccess) {
|
||||
auto dst_vectors = dst_vectors_tuple_[Number<iAccess>{}];
|
||||
|
||||
// copy data from buf_vectors into dst_bufs
|
||||
static_for<0, nDst, 1>{}([&](auto i) {
|
||||
@@ -269,7 +371,7 @@ struct ThreadwiseTensorSliceTransfer_v7r2
|
||||
});
|
||||
|
||||
// move coordinate
|
||||
if constexpr(iAccess.value != num_access - 1)
|
||||
if constexpr(iAccess.value != dst_num_access - 1)
|
||||
{
|
||||
constexpr auto forward_step = DstSpaceFillingCurve::GetForwardStep(iAccess);
|
||||
|
||||
@@ -312,28 +414,126 @@ struct ThreadwiseTensorSliceTransfer_v7r2
|
||||
|
||||
__device__ static constexpr auto GetSrcCoordinateResetStep()
|
||||
{
|
||||
if constexpr(num_access == 0)
|
||||
if constexpr(src_num_access == 0)
|
||||
{
|
||||
return typename SrcSpaceFillingCurve::Index{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return SrcSpaceFillingCurve::GetStepBetween(Number<num_access - 1>{}, Number<0>{});
|
||||
return SrcSpaceFillingCurve::GetStepBetween(Number<src_num_access - 1>{}, Number<0>{});
|
||||
}
|
||||
}
|
||||
|
||||
__device__ static constexpr auto GetDstCoordinateResetStep()
|
||||
{
|
||||
if constexpr(num_access == 0)
|
||||
if constexpr(dst_num_access == 0)
|
||||
{
|
||||
return typename DstSpaceFillingCurve::Index{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return DstSpaceFillingCurve::GetStepBetween(Number<num_access - 1>{}, Number<0>{});
|
||||
return DstSpaceFillingCurve::GetStepBetween(Number<dst_num_access - 1>{}, Number<0>{});
|
||||
}
|
||||
}
|
||||
|
||||
__device__ static constexpr auto GetSrcThreadScratchDescriptor()
|
||||
{
|
||||
// constexpr auto src_scalar_per_access = generate_sequence(
|
||||
// detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
|
||||
|
||||
constexpr auto src_access_lengths_and_vector_length = container_push_back(
|
||||
sequence_to_tuple_of_number(src_access_lengths), Number<SrcScalarPerVector>{});
|
||||
|
||||
// 1st stage of transforms
|
||||
constexpr auto desc0 =
|
||||
make_naive_tensor_descriptor_packed(src_access_lengths_and_vector_length);
|
||||
|
||||
// 2nd stage of transforms
|
||||
constexpr auto transforms = generate_tuple(
|
||||
[&](auto i) {
|
||||
if constexpr(i == SrcVectorDim)
|
||||
{
|
||||
return make_merge_transform_v3_division_mod(
|
||||
make_tuple(src_access_lengths_and_vector_length[i],
|
||||
src_access_lengths_and_vector_length[Number<nDim>{}]));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_pass_through_transform(src_access_lengths_and_vector_length[i]);
|
||||
}
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
constexpr auto low_dim_idss = generate_tuple(
|
||||
[&](auto i) {
|
||||
if constexpr(i == SrcVectorDim)
|
||||
{
|
||||
return Sequence<i.value, nDim>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return Sequence<i.value>{};
|
||||
}
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
constexpr auto up_dim_idss =
|
||||
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
|
||||
|
||||
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
|
||||
}
|
||||
|
||||
__device__ static constexpr auto GetDstThreadScratchDescriptor()
|
||||
{
|
||||
// 1st stage of transforms
|
||||
// constexpr auto dst_scalar_per_access = generate_sequence(
|
||||
// detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
|
||||
|
||||
constexpr auto dst_access_lengths_and_vector_length = container_push_back(
|
||||
sequence_to_tuple_of_number(dst_access_lengths), Number<DstScalarPerVector>{});
|
||||
|
||||
constexpr auto desc0 =
|
||||
make_naive_tensor_descriptor_packed(dst_access_lengths_and_vector_length);
|
||||
|
||||
// 2nd stage of transforms
|
||||
constexpr auto transforms = generate_tuple(
|
||||
[&](auto i) {
|
||||
if constexpr(i == DstVectorDim)
|
||||
{
|
||||
return make_merge_transform_v3_division_mod(
|
||||
make_tuple(dst_access_lengths_and_vector_length[i],
|
||||
dst_access_lengths_and_vector_length[Number<nDim>{}]));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_pass_through_transform(dst_access_lengths_and_vector_length[i]);
|
||||
}
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
constexpr auto low_dim_idss = generate_tuple(
|
||||
[&](auto i) {
|
||||
if constexpr(i == DstVectorDim)
|
||||
{
|
||||
return Sequence<i.value, nDim>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return Sequence<i.value>{};
|
||||
}
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
constexpr auto up_dim_idss =
|
||||
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
|
||||
|
||||
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
|
||||
}
|
||||
|
||||
// src_slice_origin_step_idx need to be known at compile-time, for performance reason
|
||||
template <index_t ISrc>
|
||||
__device__ void MoveSrcSliceWindow(const SrcDescs& src_descs,
|
||||
@@ -372,11 +572,14 @@ struct ThreadwiseTensorSliceTransfer_v7r2
|
||||
|
||||
private:
|
||||
using SrcVectorsType = decltype(generate_vectors<SrcDatas, SrcScalarPerVector>());
|
||||
using ElmVectorsType = decltype(generate_vectors<DstDatas, SrcScalarPerVector>());
|
||||
using DstVectorsType = decltype(generate_vectors<DstDatas, DstScalarPerVector>());
|
||||
|
||||
static constexpr auto num_access = SrcSpaceFillingCurve::GetNumOfAccess();
|
||||
static constexpr auto src_num_access = SrcSpaceFillingCurve::GetNumOfAccess();
|
||||
static constexpr auto dst_num_access = DstSpaceFillingCurve::GetNumOfAccess();
|
||||
|
||||
StaticallyIndexedArray<DstVectorsType, num_access> dst_vectors_tuple_;
|
||||
StaticallyIndexedArray<ElmVectorsType, src_num_access> elm_vectors_tuple_;
|
||||
StaticallyIndexedArray<DstVectorsType, dst_num_access> dst_vectors_tuple_;
|
||||
|
||||
SrcCoords src_coords_;
|
||||
DstCoords dst_coords_;
|
||||
|
||||
Reference in New Issue
Block a user