mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
use offsets in transfer ok
This commit is contained in:
@@ -7,7 +7,7 @@
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_description/cluster_descriptor.hpp"
|
||||
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp"
|
||||
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_gather.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
@@ -41,14 +41,15 @@ template <typename ThreadGroup,
|
||||
index_t DstScalarStrideInVector,
|
||||
bool ThreadTransferSrcResetCoordinateAfterRun,
|
||||
bool ThreadTransferDstResetCoordinateAfterRun,
|
||||
index_t GatherDim = 1,
|
||||
index_t NumThreadScratch = 1>
|
||||
struct ThreadGroupTensorSliceTransfer_v4r1_mod8
|
||||
{
|
||||
static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();
|
||||
|
||||
static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{};
|
||||
|
||||
static constexpr index_t gather_num = thread_slice_lengths.At(Number<GatherDim>{});
|
||||
using Index = MultiIndex<nDim>;
|
||||
// using GatherIndex = MultiIndex<gather_num>;
|
||||
|
||||
__device__ constexpr ThreadGroupTensorSliceTransfer_v4r1_mod8(
|
||||
const SrcDesc& src_desc,
|
||||
@@ -56,13 +57,15 @@ struct ThreadGroupTensorSliceTransfer_v4r1_mod8
|
||||
const SrcElementwiseOperation& src_element_op,
|
||||
const DstDesc& dst_desc,
|
||||
const Index& dst_block_slice_origin,
|
||||
const DstElementwiseOperation& dst_element_op)
|
||||
const DstElementwiseOperation& dst_element_op,
|
||||
const StaticallyIndexedArray<index_t, gather_num> &gather_offsets)
|
||||
: threadwise_transfer_(src_desc,
|
||||
make_zero_multi_index<nDim>(),
|
||||
src_element_op,
|
||||
dst_desc,
|
||||
make_zero_multi_index<nDim>(),
|
||||
dst_element_op)
|
||||
dst_element_op,
|
||||
gather_offsets)
|
||||
|
||||
{
|
||||
static_assert(nDim == remove_cvref_t<SrcDesc>::GetNumOfDimension() &&
|
||||
@@ -173,7 +176,7 @@ struct ThreadGroupTensorSliceTransfer_v4r1_mod8
|
||||
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
|
||||
|
||||
using ThreadwiseTransfer =
|
||||
ThreadwiseTensorSliceTransfer_v3r1<decltype(thread_slice_lengths),
|
||||
ThreadwiseTensorSliceTransfer_v3r1_gather<decltype(thread_slice_lengths),
|
||||
SrcElementwiseOperation,
|
||||
DstElementwiseOperation,
|
||||
DstInMemOp,
|
||||
@@ -191,6 +194,7 @@ struct ThreadGroupTensorSliceTransfer_v4r1_mod8
|
||||
DstScalarStrideInVector,
|
||||
ThreadTransferSrcResetCoordinateAfterRun,
|
||||
ThreadTransferDstResetCoordinateAfterRun,
|
||||
GatherDim,
|
||||
NumThreadScratch>;
|
||||
|
||||
ThreadwiseTransfer threadwise_transfer_;
|
||||
|
||||
@@ -1132,8 +1132,11 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
|
||||
constexpr auto MLoadRepeats = MPerBlock / MLoadThreads;
|
||||
static_assert(MLoadRepeats == 1, "only support 1 line per thread now!");
|
||||
const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / KLoadThreads;
|
||||
|
||||
index_t token_offset = p_sorted_token_ids[token_pos];
|
||||
StaticallyIndexedArray<index_t, MLoadRepeats> token_offsets; //= p_sorted_token_ids[token_pos];
|
||||
static_for<0, MLoadRepeats, 1>{}([&](auto m0) {
|
||||
token_offsets(m0) = p_sorted_token_ids[token_pos + MLoadThreads * m0] * problem.K;
|
||||
});
|
||||
printf("threadIdx.x %d off %d\n", threadIdx.x, token_offsets(I0));
|
||||
const index_t m_block_data_idx_on_grid =
|
||||
__builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
|
||||
const index_t expert_stride = __builtin_amdgcn_readfirstlane(problem.N * problem.K);
|
||||
@@ -1183,13 +1186,15 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
|
||||
1,
|
||||
AThreadTransferSrcResetCoordinateAfterRun,
|
||||
true,
|
||||
1,
|
||||
BlockwiseGemmPipe::GlobalBufferNum>(
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
make_multi_index(0, token_offset, 0),
|
||||
make_multi_index(0, 0, 0),
|
||||
a_element_op,
|
||||
a_block_desc_ak0_m_ak1,
|
||||
make_multi_index(0, 0, 0),
|
||||
ck::tensor_operation::element_wise::PassThrough{});
|
||||
ck::tensor_operation::element_wise::PassThrough{},
|
||||
token_offsets);
|
||||
|
||||
// Thread-wise copy
|
||||
// K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
|
||||
|
||||
@@ -0,0 +1,902 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
|
||||
#include "ck/tensor/static_tensor.hpp"
|
||||
#include "ck/utility/is_detected.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// Assume:
|
||||
// 1. src_desc and dst_desc are not known at compile-time
|
||||
// 2. SrcBuffer and DstBuffer are DynamicBuffer
|
||||
// 3. src_slice_origin and dst_slice_origin are not known at compile-time,
|
||||
// 4. Use thread buffer
|
||||
template <typename SliceLengths,
|
||||
typename SrcElementwiseOperation,
|
||||
typename DstElementwiseOperation,
|
||||
InMemoryDataOperationEnum DstInMemOp,
|
||||
typename SrcData,
|
||||
typename DstData,
|
||||
typename SrcDesc,
|
||||
typename DstDesc,
|
||||
typename SrcDimAccessOrder,
|
||||
typename DstDimAccessOrder,
|
||||
index_t SrcVectorDim,
|
||||
index_t DstVectorDim,
|
||||
index_t SrcScalarPerVector,
|
||||
index_t DstScalarPerVector,
|
||||
index_t SrcScalarStrideInVector,
|
||||
index_t DstScalarStrideInVector,
|
||||
bool SrcResetCoordinateAfterRun, // control whether to move back src coordinate after each
|
||||
// RunRead(), will be fused with MoveSrcSliceWindow to
|
||||
// save addr computation
|
||||
bool DstResetCoordinateAfterRun, // control whether to move back dst coordinate after each
|
||||
// RunWrite(), will be fused with MoveDstSliceWindow to
|
||||
// save addr computation
|
||||
index_t GatherDim = 1,
|
||||
index_t NumThreadScratch = 1>
|
||||
struct ThreadwiseTensorSliceTransfer_v3r1_gather
|
||||
{
|
||||
static constexpr index_t nDim = SliceLengths::Size();
|
||||
using Index = MultiIndex<nDim>;
|
||||
|
||||
using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
|
||||
using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
|
||||
|
||||
using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
|
||||
using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr index_t gather_num = SliceLengths{}.At(Number<GatherDim>{});
|
||||
|
||||
__device__ constexpr ThreadwiseTensorSliceTransfer_v3r1_gather(
|
||||
const SrcDesc& src_desc,
|
||||
const Index& src_slice_origin,
|
||||
const SrcElementwiseOperation& src_element_op,
|
||||
const DstDesc& dst_desc,
|
||||
const Index& dst_slice_origin,
|
||||
const DstElementwiseOperation& dst_element_op,
|
||||
const StaticallyIndexedArray<index_t, gather_num> &gather_offsets)
|
||||
: src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)),
|
||||
dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)),
|
||||
src_element_op_(src_element_op),
|
||||
dst_element_op_(dst_element_op),
|
||||
gather_offsets_(gather_offsets)
|
||||
{
|
||||
}
|
||||
|
||||
__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
|
||||
{
|
||||
src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx);
|
||||
}
|
||||
|
||||
__device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
|
||||
{
|
||||
dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx);
|
||||
}
|
||||
|
||||
template <typename SrcBuffer, index_t ThreadScratchId = 0>
|
||||
__device__ void RunRead(const SrcDesc& src_desc,
|
||||
const SrcBuffer& src_buf,
|
||||
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
|
||||
{
|
||||
static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Global or
|
||||
SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Lds,
|
||||
"wrong!");
|
||||
|
||||
static_assert(
|
||||
is_same<remove_cvref_t<typename SrcBuffer::type>, remove_cvref_t<SrcData>>::value,
|
||||
"wrong! SrcBuffer and SrcData data type are inconsistent");
|
||||
|
||||
// scalar per access on each dim
|
||||
// TODO: don't use lambda_scalar_per_access
|
||||
constexpr auto src_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
|
||||
static_assert(SliceLengths::At(SrcVectorDim) % SrcScalarPerVector == 0,
|
||||
"SliceLengths[SrcVectorDim] must be divisible by SrcScalarPerVector");
|
||||
|
||||
constexpr auto src_dim_access_order = SrcDimAccessOrder{};
|
||||
constexpr auto ordered_gather_dim = src_dim_access_order[GatherDim];
|
||||
constexpr auto ordered_src_access_lengths =
|
||||
container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
|
||||
|
||||
// make forward steps
|
||||
const auto src_forward_steps = generate_tuple(
|
||||
[&](auto i) {
|
||||
Index forward_step_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto j) {
|
||||
forward_step_idx(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0;
|
||||
});
|
||||
|
||||
return make_tensor_coordinate_step(src_desc, forward_step_idx);
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
// make backward steps
|
||||
const auto src_backward_steps = generate_tuple(
|
||||
[&](auto i) {
|
||||
Index backward_step_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto j) {
|
||||
backward_step_idx(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0;
|
||||
});
|
||||
|
||||
return make_tensor_coordinate_step(src_desc, backward_step_idx);
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
// loop over tensor and copy
|
||||
static_ford<decltype(ordered_src_access_lengths)>{}([&](auto ordered_src_access_idx) {
|
||||
// judge move forward or move backward
|
||||
constexpr auto forward_sweep = [&]() {
|
||||
StaticallyIndexedArray<bool, nDim> forward_sweep_;
|
||||
|
||||
forward_sweep_(I0) = true;
|
||||
|
||||
static_for<1, nDim, 1>{}([&](auto i) {
|
||||
index_t tmp = ordered_src_access_idx[I0];
|
||||
|
||||
static_for<1, i, 1>{}([&](auto j) {
|
||||
tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j];
|
||||
});
|
||||
|
||||
forward_sweep_(i) = tmp % 2 == 0;
|
||||
});
|
||||
|
||||
return forward_sweep_;
|
||||
}();
|
||||
|
||||
// calculate src data index
|
||||
constexpr auto src_data_idx = [&]() {
|
||||
Index ordered_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
ordered_idx(i) = forward_sweep[i] ? ordered_src_access_idx[i]
|
||||
: ordered_src_access_lengths[i] - 1 -
|
||||
ordered_src_access_idx[i];
|
||||
});
|
||||
|
||||
return container_reorder_given_old2new(ordered_idx, src_dim_access_order) *
|
||||
src_scalar_per_access;
|
||||
}();
|
||||
|
||||
constexpr auto src_data_idx_seq = generate_sequence_v2(
|
||||
[&](auto i) { return Number<src_data_idx[i]>{}; }, Number<src_data_idx.Size()>{});
|
||||
|
||||
auto gather_offset = gather_offsets_(I0);//ordered_src_access_idx[Number<ordered_gather_dim>{}]);
|
||||
|
||||
// maintain a container record is_src_valid, waiting for RunWrite use.
|
||||
const bool is_src_valid =
|
||||
coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_) && (gather_offset < 32*512);
|
||||
src_oob_thread_scratch_tuple_(thread_scratch_id)
|
||||
.template SetAsType<bool>(src_data_idx_seq, is_src_valid);
|
||||
|
||||
using src_vector_type = vector_type_maker_t<SrcData, SrcScalarPerVector>;
|
||||
using src_vector_t = typename src_vector_type::type;
|
||||
if(blockIdx.x+blockIdx.y==0)
|
||||
printf("tid %d off %d %d\n", threadIdx.x, src_coord_.GetOffset(), gather_offset );
|
||||
auto src_vector_container =
|
||||
src_vector_type{src_buf.template Get<src_vector_t>(src_coord_.GetOffset() + gather_offset, true)};
|
||||
|
||||
using dst_vector_type = vector_type_maker_t<DstData, SrcScalarPerVector>;
|
||||
using dst_vector_t = typename dst_vector_type::type;
|
||||
dst_vector_type op_r_v;
|
||||
|
||||
constexpr auto get_elem_op_vec_len = []() {
|
||||
if constexpr(is_detected<is_pack8_invocable_t, decltype(src_element_op_)>::value)
|
||||
{
|
||||
if constexpr(decltype(src_element_op_)::is_pack8_invocable)
|
||||
return math::min(8, SrcScalarPerVector);
|
||||
}
|
||||
if constexpr(is_detected<is_pack4_invocable_t, decltype(src_element_op_)>::value)
|
||||
{
|
||||
if constexpr(decltype(src_element_op_)::is_pack4_invocable)
|
||||
return math::min(4, SrcScalarPerVector);
|
||||
}
|
||||
if constexpr(is_detected<is_pack2_invocable_t, decltype(src_element_op_)>::value)
|
||||
{
|
||||
if constexpr(decltype(src_element_op_)::is_pack2_invocable)
|
||||
return math::min(2, SrcScalarPerVector);
|
||||
}
|
||||
return 1;
|
||||
};
|
||||
|
||||
constexpr index_t elem_op_vec_len = get_elem_op_vec_len();
|
||||
|
||||
using src_elem_op_vec_t = typename vector_type<SrcData, elem_op_vec_len>::type;
|
||||
using dst_elem_op_vec_t = typename vector_type<DstData, elem_op_vec_len>::type;
|
||||
|
||||
static_for<0, SrcScalarPerVector / elem_op_vec_len, 1>{}([&](auto idx) {
|
||||
// apply the src elementwise op and convert to DstData under the hood if needed
|
||||
src_element_op_(op_r_v.template AsType<dst_elem_op_vec_t>()(idx),
|
||||
src_vector_container.template AsType<src_elem_op_vec_t>()[idx]);
|
||||
});
|
||||
|
||||
// copy data from src_vector_container into src_thread_scratch_
|
||||
src_thread_scratch_tuple_(thread_scratch_id)
|
||||
.template SetAsType<dst_vector_t>(src_data_idx_seq,
|
||||
op_r_v.template AsType<dst_vector_t>()[I0]);
|
||||
|
||||
// if(1) {
|
||||
// using print_vec_t = typename vector_type<DstData, 1>::type;
|
||||
// static_for<0, SrcScalarPerVector, 1>{}([&](auto idx) {
|
||||
// printf("tid %d %f\n",threadIdx.x, type_convert<float>(src_vector_container.template AsType<print_vec_t>()[idx]));
|
||||
// });
|
||||
// }
|
||||
constexpr auto move_on_dim = [&]() constexpr
|
||||
{
|
||||
StaticallyIndexedArray<bool, nDim> move_on_dim_;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
move_on_dim_(i) = ordered_src_access_idx[i] < ordered_src_access_lengths[i] - 1;
|
||||
|
||||
static_for<i + 1, nDim, 1>{}([&](auto j) {
|
||||
move_on_dim_(i) &=
|
||||
ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1;
|
||||
});
|
||||
});
|
||||
|
||||
return move_on_dim_;
|
||||
}
|
||||
();
|
||||
|
||||
// move src coord
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
if constexpr(move_on_dim[i])
|
||||
{
|
||||
if constexpr(forward_sweep[i])
|
||||
{
|
||||
move_tensor_coordinate(
|
||||
src_desc, src_coord_, src_forward_steps[src_dim_access_order[i]]);
|
||||
}
|
||||
else
|
||||
{
|
||||
move_tensor_coordinate(
|
||||
src_desc, src_coord_, src_backward_steps[src_dim_access_order[i]]);
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
// move src coordinate back to slice origin (or not)
|
||||
if constexpr(SrcResetCoordinateAfterRun)
|
||||
{
|
||||
const auto src_reset_step =
|
||||
make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep());
|
||||
|
||||
move_tensor_coordinate(src_desc, src_coord_, src_reset_step);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SeqIdx, index_t ThreadScratchId = 0>
|
||||
__device__ constexpr auto
|
||||
GetSrcThreadScratchIdx(Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
|
||||
{
|
||||
using vector_t = typename vector_type_maker<SrcData, SrcScalarPerVector>::type::type;
|
||||
return src_thread_scratch_tuple_(thread_scratch_id).template GetAsType<vector_t>(SeqIdx{});
|
||||
}
|
||||
|
||||
template <index_t ThreadScratchId>
|
||||
__device__ void
|
||||
TransferDataFromSrcThreadScratchToDstThreadScratch(Number<ThreadScratchId> thread_scratch_id)
|
||||
{
|
||||
#if !CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE
|
||||
static_ford<SliceLengths>{}([&](auto idx) {
|
||||
dst_thread_scratch_(idx) = src_thread_scratch_tuple_[thread_scratch_id][idx];
|
||||
});
|
||||
#else
|
||||
|
||||
// OOB Check
|
||||
constexpr auto src_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
|
||||
|
||||
constexpr auto src_dim_access_order = SrcDimAccessOrder{};
|
||||
|
||||
constexpr auto ordered_src_access_lengths =
|
||||
container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
|
||||
|
||||
// loop over tensor and copy
|
||||
static_ford<decltype(ordered_src_access_lengths)>{}([&](auto ordered_src_access_idx) {
|
||||
// judge move forward or move backward
|
||||
constexpr auto forward_sweep = [&]() {
|
||||
StaticallyIndexedArray<bool, nDim> forward_sweep_;
|
||||
|
||||
forward_sweep_(I0) = true;
|
||||
|
||||
static_for<1, nDim, 1>{}([&](auto i) {
|
||||
index_t tmp = ordered_src_access_idx[I0];
|
||||
|
||||
static_for<1, i, 1>{}([&](auto j) {
|
||||
tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j];
|
||||
});
|
||||
|
||||
forward_sweep_(i) = tmp % 2 == 0;
|
||||
});
|
||||
|
||||
return forward_sweep_;
|
||||
}();
|
||||
|
||||
// calculate src data index
|
||||
constexpr auto src_data_idx = [&]() {
|
||||
Index ordered_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
ordered_idx(i) = forward_sweep[i] ? ordered_src_access_idx[i]
|
||||
: ordered_src_access_lengths[i] - 1 -
|
||||
ordered_src_access_idx[i];
|
||||
});
|
||||
|
||||
return container_reorder_given_old2new(ordered_idx, src_dim_access_order) *
|
||||
src_scalar_per_access;
|
||||
}();
|
||||
|
||||
constexpr auto src_data_idx_seq = generate_sequence_v2(
|
||||
[&](auto i) { return Number<src_data_idx[i]>{}; }, Number<src_data_idx.Size()>{});
|
||||
|
||||
using vector_t = typename vector_type_maker<DstData, SrcScalarPerVector>::type::type;
|
||||
|
||||
auto op_r = src_thread_scratch_tuple_(thread_scratch_id)
|
||||
.template GetAsType<vector_t>(src_data_idx_seq);
|
||||
|
||||
const bool is_src_valid = src_oob_thread_scratch_tuple_(thread_scratch_id)
|
||||
.template GetAsType<bool>(src_data_idx_seq);
|
||||
|
||||
auto op_r_v = is_src_valid ? op_r : vector_t(0);
|
||||
|
||||
src_thread_scratch_tuple_(thread_scratch_id)
|
||||
.template SetAsType<vector_t>(src_data_idx_seq, op_r_v);
|
||||
});
|
||||
|
||||
// sub-dword transpose between src_thread_scratch_ and dst_thread_scratch_
|
||||
// TODO make this logic more generic for more sub-dword datatype
|
||||
if constexpr(SrcVectorDim != DstVectorDim &&
|
||||
((is_same<half_t, remove_cvref_t<DstData>>::value &&
|
||||
SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0) ||
|
||||
(is_same<int8_t, remove_cvref_t<DstData>>::value &&
|
||||
SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0) ||
|
||||
(is_same<f8_t, remove_cvref_t<DstData>>::value &&
|
||||
SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0)))
|
||||
{
|
||||
// each transpose does
|
||||
// DstScalarPerVector # of src vectors in src_thread_scratch_
|
||||
// SrcScalarPerVector # of dst vectors in dst_thread_scratch_
|
||||
constexpr index_t num_src_vector = Number<DstScalarPerVector>{};
|
||||
constexpr index_t num_dst_vector = Number<SrcScalarPerVector>{};
|
||||
|
||||
// Assume SrcVectorDim is not the same as DstVectorDim, so we do transpose
|
||||
// TODO: make this logic generic for all scenario
|
||||
static_assert(SrcVectorDim != DstVectorDim, "wrong");
|
||||
|
||||
constexpr auto src_scalar_step_in_vector = generate_sequence(
|
||||
detail::lambda_scalar_step_in_vector<SrcVectorDim>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto dst_scalar_step_in_vector = generate_sequence(
|
||||
detail::lambda_scalar_step_in_vector<DstVectorDim>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access_for_src_and_dst<SrcVectorDim,
|
||||
SrcScalarPerVector,
|
||||
DstVectorDim,
|
||||
DstScalarPerVector>{},
|
||||
Number<nDim>{});
|
||||
|
||||
constexpr auto access_lengths = SliceLengths{} / scalar_per_access;
|
||||
|
||||
static_ford<decltype(access_lengths)>{}([&](auto access_idx) {
|
||||
constexpr auto data_idx = access_idx * scalar_per_access;
|
||||
|
||||
constexpr auto data_idx_seq = generate_sequence_v2(
|
||||
[&](auto i) { return Number<data_idx[i]>{}; }, Number<nDim>{});
|
||||
|
||||
using src_vector_t = vector_type_maker_t<DstData, SrcScalarPerVector>;
|
||||
using dst_vector_t = vector_type_maker_t<DstData, DstScalarPerVector>;
|
||||
|
||||
// get DstScalarPerVector # of read-only references to src vectors from
|
||||
// src_thread_scratch_
|
||||
const auto src_vector_refs = generate_tie(
|
||||
[&](auto i) -> const src_vector_t& {
|
||||
// i increment corresponds to movement in DstVectorDim
|
||||
return src_thread_scratch_tuple_[thread_scratch_id].GetVectorTypeReference(
|
||||
data_idx_seq + i * dst_scalar_step_in_vector);
|
||||
},
|
||||
Number<num_src_vector>{});
|
||||
|
||||
// get SrcScalarPerVector # of references to dst vectors from dst_thread_scratch_
|
||||
auto dst_vector_refs = generate_tie(
|
||||
[&](auto i) -> dst_vector_t& {
|
||||
// i increment corresponds to movement in SrcVectorDim
|
||||
return dst_thread_scratch_.GetVectorTypeReference(
|
||||
data_idx_seq + i * src_scalar_step_in_vector);
|
||||
},
|
||||
Number<num_dst_vector>{});
|
||||
|
||||
// do data transpose
|
||||
transpose_vectors<DstData, DstScalarPerVector, SrcScalarPerVector>{}(
|
||||
src_vector_refs, dst_vector_refs);
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
static_ford<SliceLengths>{}([&](auto idx) {
|
||||
dst_thread_scratch_(idx) = src_thread_scratch_tuple_[thread_scratch_id][idx];
|
||||
});
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename DstBuffer, index_t ThreadScratchId = 0>
|
||||
__device__ void RunWrite(const DstDesc& dst_desc,
|
||||
DstBuffer& dst_buf,
|
||||
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
|
||||
{
|
||||
// if there is transpose, it's done here
|
||||
// if there is oob check, it's done here
|
||||
// TODO move this elsewhere
|
||||
TransferDataFromSrcThreadScratchToDstThreadScratch(thread_scratch_id);
|
||||
|
||||
static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum::Global or
|
||||
DstBuffer::GetAddressSpace() == AddressSpaceEnum::Lds,
|
||||
"wrong!");
|
||||
|
||||
static_assert(
|
||||
is_same<remove_cvref_t<typename DstBuffer::type>, remove_cvref_t<DstData>>::value,
|
||||
"wrong! SrcBuffer or DstBuffer data type is wrong");
|
||||
|
||||
// src scalar per access on each dim
|
||||
// TODO: don't use this
|
||||
constexpr auto dst_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
|
||||
|
||||
constexpr auto dst_dim_access_order = DstDimAccessOrder{};
|
||||
|
||||
constexpr auto ordered_dst_access_lengths =
|
||||
container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);
|
||||
|
||||
// make forward steps
|
||||
const auto dst_forward_steps = generate_tuple(
|
||||
[&](auto i) {
|
||||
Index forward_step_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto j) {
|
||||
forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0;
|
||||
});
|
||||
|
||||
return make_tensor_coordinate_step(dst_desc, forward_step_idx);
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
// make backward steps
|
||||
const auto dst_backward_steps = generate_tuple(
|
||||
[&](auto i) {
|
||||
Index backward_step_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto j) {
|
||||
backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0;
|
||||
});
|
||||
|
||||
return make_tensor_coordinate_step(dst_desc, backward_step_idx);
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
// loop over tensor and copy
|
||||
static_ford<decltype(ordered_dst_access_lengths)>{}([&](auto ordered_dst_access_idx) {
|
||||
// judge move forward or move backward
|
||||
constexpr auto forward_sweep = [&]() {
|
||||
StaticallyIndexedArray<bool, nDim> forward_sweep_;
|
||||
|
||||
forward_sweep_(I0) = true;
|
||||
|
||||
static_for<1, nDim, 1>{}([&](auto i) {
|
||||
index_t tmp = ordered_dst_access_idx[I0];
|
||||
|
||||
static_for<1, i, 1>{}([&](auto j) {
|
||||
tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j];
|
||||
});
|
||||
|
||||
forward_sweep_(i) = tmp % 2 == 0;
|
||||
});
|
||||
|
||||
return forward_sweep_;
|
||||
}();
|
||||
|
||||
// calculate dst data index
|
||||
constexpr auto dst_data_idx = [&]() {
|
||||
Index ordered_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_idx[i]
|
||||
: ordered_dst_access_lengths[i] - 1 -
|
||||
ordered_dst_access_idx[i];
|
||||
});
|
||||
|
||||
return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) *
|
||||
dst_scalar_per_access;
|
||||
}();
|
||||
|
||||
constexpr auto dst_data_idx_seq = generate_sequence_v2(
|
||||
[&](auto i) { return Number<dst_data_idx[i]>{}; }, Number<dst_data_idx.Size()>{});
|
||||
|
||||
const bool is_dst_valid =
|
||||
coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);
|
||||
|
||||
using dst_vector_type = vector_type_maker_t<DstData, DstScalarPerVector>;
|
||||
using dst_vector_t = typename dst_vector_type::type;
|
||||
|
||||
// copy data from dst_thread_scratch_ into dst_vector_container
|
||||
auto dst_vector_container = dst_vector_type{
|
||||
dst_thread_scratch_.template GetAsType<dst_vector_t>(dst_data_idx_seq)};
|
||||
|
||||
static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
|
||||
DstData dst_v;
|
||||
|
||||
// apply DstElementwiseOperation
|
||||
dst_element_op_(dst_v, dst_vector_container.template AsType<DstData>()[i]);
|
||||
|
||||
dst_vector_container.template AsType<DstData>()(i) = dst_v;
|
||||
});
|
||||
|
||||
// copy data from dst_vector_container to dst_buf
|
||||
dst_buf.template Set<dst_vector_t>(
|
||||
dst_coord_.GetOffset(),
|
||||
is_dst_valid,
|
||||
dst_vector_container.template AsType<dst_vector_t>()[I0]);
|
||||
|
||||
// if(1) {
|
||||
// using print_vec_t = typename vector_type<DstData, 1>::type;
|
||||
// static_for<0, DstScalarPerVector, 1>{}([&](auto idx) {
|
||||
// printf("tid %d off %d valid %d val %f\n",threadIdx.x, dst_coord_.GetOffset(), is_dst_valid, type_convert<float>(dst_vector_container.template AsType<print_vec_t>()[idx]));
|
||||
// });
|
||||
// }
|
||||
constexpr auto move_on_dim = [&]() constexpr
|
||||
{
|
||||
StaticallyIndexedArray<bool, nDim> move_on_dim_;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
move_on_dim_(i) = ordered_dst_access_idx[i] < ordered_dst_access_lengths[i] - 1;
|
||||
|
||||
static_for<i + 1, nDim, 1>{}([&](auto j) {
|
||||
move_on_dim_(i) &=
|
||||
ordered_dst_access_idx[j] == ordered_dst_access_lengths[j] - 1;
|
||||
});
|
||||
});
|
||||
|
||||
return move_on_dim_;
|
||||
}
|
||||
();
|
||||
|
||||
// move dst coord
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
if constexpr(move_on_dim[i])
|
||||
{
|
||||
if constexpr(forward_sweep[i])
|
||||
{
|
||||
move_tensor_coordinate(
|
||||
dst_desc, dst_coord_, dst_forward_steps[dst_dim_access_order[i]]);
|
||||
}
|
||||
else
|
||||
{
|
||||
move_tensor_coordinate(
|
||||
dst_desc, dst_coord_, dst_backward_steps[dst_dim_access_order[i]]);
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
// move dst coordinate back to slice origin (or not)
|
||||
if constexpr(DstResetCoordinateAfterRun)
|
||||
{
|
||||
const auto dst_reset_step =
|
||||
make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep());
|
||||
|
||||
move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ static constexpr auto GetSrcCoordinateResetStep()
|
||||
{
|
||||
// scalar per access on each dim
|
||||
// TODO: don't use lambda_scalar_per_access
|
||||
constexpr auto src_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
|
||||
|
||||
constexpr auto src_dim_access_order = SrcDimAccessOrder{};
|
||||
|
||||
constexpr auto ordered_src_access_lengths =
|
||||
container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
|
||||
|
||||
// judge move forward or move backward during the last iteration
|
||||
constexpr auto forward_sweep = [&]() {
|
||||
StaticallyIndexedArray<bool, nDim> forward_sweep_;
|
||||
|
||||
forward_sweep_(I0) = true;
|
||||
|
||||
static_for<1, nDim, 1>{}([&](auto i) {
|
||||
index_t tmp = ordered_src_access_lengths[I0] - 1;
|
||||
|
||||
static_for<1, i, 1>{}([&](auto j) {
|
||||
tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1;
|
||||
});
|
||||
|
||||
forward_sweep_(i) = tmp % 2 == 0;
|
||||
});
|
||||
|
||||
return forward_sweep_;
|
||||
}();
|
||||
|
||||
// calculate src data index after last iteration in RunRead(), if it has not being reset by
|
||||
// RunRead()
|
||||
constexpr auto src_data_idx = [&]() {
|
||||
Index ordered_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
ordered_idx(i) = forward_sweep[i] ? ordered_src_access_lengths[i] - 1 : 0;
|
||||
});
|
||||
|
||||
return container_reorder_given_old2new(ordered_idx, src_dim_access_order) *
|
||||
src_scalar_per_access;
|
||||
}();
|
||||
|
||||
//
|
||||
constexpr auto reset_src_data_step = [&]() {
|
||||
Index reset_src_data_step_;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step_(i) = -src_data_idx[i]; });
|
||||
|
||||
return reset_src_data_step_;
|
||||
}();
|
||||
|
||||
return reset_src_data_step;
|
||||
}
|
||||
|
||||
__device__ static constexpr auto GetDstCoordinateResetStep()
|
||||
{
|
||||
// scalar per access on each dim
|
||||
// TODO: don't use lambda_scalar_per_access
|
||||
constexpr auto dst_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
|
||||
|
||||
constexpr auto dst_dim_access_order = DstDimAccessOrder{};
|
||||
|
||||
constexpr auto ordered_dst_access_lengths =
|
||||
container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);
|
||||
|
||||
// judge move forward or move backward during the last iteration
|
||||
constexpr auto forward_sweep = [&]() {
|
||||
StaticallyIndexedArray<bool, nDim> forward_sweep_;
|
||||
|
||||
forward_sweep_(I0) = true;
|
||||
|
||||
static_for<1, nDim, 1>{}([&](auto i) {
|
||||
index_t tmp = ordered_dst_access_lengths[I0] - 1;
|
||||
|
||||
static_for<1, i, 1>{}([&](auto j) {
|
||||
tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1;
|
||||
});
|
||||
|
||||
forward_sweep_(i) = tmp % 2 == 0;
|
||||
});
|
||||
|
||||
return forward_sweep_;
|
||||
}();
|
||||
|
||||
// calculate dst data index after last iteration in RunWrite(), if it has not being reset by
|
||||
// RunWrite()
|
||||
constexpr auto dst_data_idx = [&]() {
|
||||
Index ordered_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_lengths[i] - 1 : 0;
|
||||
});
|
||||
|
||||
return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) *
|
||||
dst_scalar_per_access;
|
||||
}();
|
||||
|
||||
//
|
||||
constexpr auto reset_dst_data_step = [&]() {
|
||||
Index reset_dst_data_step_;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step_(i) = -dst_data_idx[i]; });
|
||||
|
||||
return reset_dst_data_step_;
|
||||
}();
|
||||
|
||||
return reset_dst_data_step;
|
||||
}
|
||||
|
||||
// src_slice_origin_step_idx need to be known at compile-time, for performance reason
|
||||
__device__ void MoveSrcSliceWindow(const SrcDesc& src_desc,
|
||||
const Index& src_slice_origin_step_idx)
|
||||
{
|
||||
// if src coord was not reset by RunRead(), then need to adjust the step here
|
||||
const auto adjusted_step_idx =
|
||||
SrcResetCoordinateAfterRun ? src_slice_origin_step_idx
|
||||
: src_slice_origin_step_idx + GetSrcCoordinateResetStep();
|
||||
|
||||
// is it OK to construct a new step every time?
|
||||
const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx);
|
||||
|
||||
move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
|
||||
}
|
||||
|
||||
// dst_slice_origin_step_idx need to be known at compile-time, for performance reason
|
||||
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc,
|
||||
const Index& dst_slice_origin_step_idx)
|
||||
{
|
||||
// if dst coord was not reset by RunWrite(), then need to adjust the step here
|
||||
const auto adjusted_step_idx =
|
||||
DstResetCoordinateAfterRun ? dst_slice_origin_step_idx
|
||||
: dst_slice_origin_step_idx + GetDstCoordinateResetStep();
|
||||
|
||||
// is it OK to construct a new step every time?
|
||||
const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx);
|
||||
|
||||
move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
|
||||
}
|
||||
|
||||
__device__ static constexpr auto GetSrcThreadScratchDescriptor()
|
||||
{
|
||||
constexpr auto src_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
|
||||
|
||||
constexpr auto src_access_lengths_and_vector_length = container_push_back(
|
||||
sequence_to_tuple_of_number(src_access_lengths), Number<SrcScalarPerVector>{});
|
||||
|
||||
// 1st stage of transforms
|
||||
constexpr auto desc0 =
|
||||
make_naive_tensor_descriptor_packed(src_access_lengths_and_vector_length);
|
||||
|
||||
// 2nd stage of transforms
|
||||
constexpr auto transforms = generate_tuple(
|
||||
[&](auto i) {
|
||||
if constexpr(i == SrcVectorDim)
|
||||
{
|
||||
return make_merge_transform_v3_division_mod(
|
||||
make_tuple(src_access_lengths_and_vector_length[i],
|
||||
src_access_lengths_and_vector_length[Number<nDim>{}]));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_pass_through_transform(src_access_lengths_and_vector_length[i]);
|
||||
}
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
constexpr auto low_dim_idss = generate_tuple(
|
||||
[&](auto i) {
|
||||
if constexpr(i == SrcVectorDim)
|
||||
{
|
||||
return Sequence<i.value, nDim>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return Sequence<i.value>{};
|
||||
}
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
constexpr auto up_dim_idss =
|
||||
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
|
||||
|
||||
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
|
||||
}
|
||||
|
||||
__device__ static constexpr auto GetSrcOOBThreadScratchDescriptor()
|
||||
{
|
||||
constexpr auto src_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
|
||||
|
||||
return make_naive_tensor_descriptor_packed(src_access_lengths);
|
||||
}
|
||||
|
||||
__device__ static constexpr auto GetDstThreadScratchDescriptor()
|
||||
{
|
||||
// 1st stage of transforms
|
||||
constexpr auto dst_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
|
||||
|
||||
constexpr auto dst_access_lengths_and_vector_length = container_push_back(
|
||||
sequence_to_tuple_of_number(dst_access_lengths), Number<DstScalarPerVector>{});
|
||||
|
||||
constexpr auto desc0 =
|
||||
make_naive_tensor_descriptor_packed(dst_access_lengths_and_vector_length);
|
||||
|
||||
// 2nd stage of transforms
|
||||
constexpr auto transforms = generate_tuple(
|
||||
[&](auto i) {
|
||||
if constexpr(i == DstVectorDim)
|
||||
{
|
||||
return make_merge_transform_v3_division_mod(
|
||||
make_tuple(dst_access_lengths_and_vector_length[i],
|
||||
dst_access_lengths_and_vector_length[Number<nDim>{}]));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_pass_through_transform(dst_access_lengths_and_vector_length[i]);
|
||||
}
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
constexpr auto low_dim_idss = generate_tuple(
|
||||
[&](auto i) {
|
||||
if constexpr(i == DstVectorDim)
|
||||
{
|
||||
return Sequence<i.value, nDim>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return Sequence<i.value>{};
|
||||
}
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
constexpr auto up_dim_idss =
|
||||
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
|
||||
|
||||
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
|
||||
}
|
||||
|
||||
private:
|
||||
static constexpr auto src_thread_scratch_desc_ = decltype(GetSrcThreadScratchDescriptor()){};
|
||||
static constexpr auto src_oob_thread_scratch_desc_ =
|
||||
decltype(GetSrcThreadScratchDescriptor()){};
|
||||
static constexpr auto dst_thread_scratch_desc_ = decltype(GetDstThreadScratchDescriptor()){};
|
||||
|
||||
using SrcThreadScratch =
|
||||
StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
|
||||
DstData, // apply data_convert with SrcThreadScratch
|
||||
SrcScalarPerVector,
|
||||
decltype(src_thread_scratch_desc_),
|
||||
true>;
|
||||
|
||||
using SrcOOBThreadScratch =
|
||||
StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
|
||||
bool, // apply data_convert with SrcThreadScratch
|
||||
1,
|
||||
decltype(src_oob_thread_scratch_desc_),
|
||||
true>;
|
||||
|
||||
using DstThreadScratch = StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
|
||||
DstData,
|
||||
DstScalarPerVector,
|
||||
decltype(dst_thread_scratch_desc_),
|
||||
true>;
|
||||
|
||||
StaticallyIndexedArray<SrcThreadScratch, NumThreadScratch> src_thread_scratch_tuple_;
|
||||
StaticallyIndexedArray<SrcOOBThreadScratch, NumThreadScratch> src_oob_thread_scratch_tuple_;
|
||||
|
||||
DstThreadScratch dst_thread_scratch_;
|
||||
|
||||
SrcCoord src_coord_;
|
||||
DstCoord dst_coord_;
|
||||
StaticallyIndexedArray<index_t, gather_num> gather_offsets_;
|
||||
const SrcElementwiseOperation src_element_op_;
|
||||
const DstElementwiseOperation dst_element_op_;
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
Reference in New Issue
Block a user