mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 05:31:24 +00:00
ck moe gemm implement (#1936)
* port all moe changes from ck_moe_gemm branch * refine codes in the pr * fix tail odd * fix clang format * fix clang format2 * make hot loop scheduler compatible with 16x16 and 32x32 * clang format * fix per token quant * rename moe example * clang format --------- Co-authored-by: coderfeli <coderfeli@163.com>
This commit is contained in:
@@ -141,6 +141,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
|
||||
|
||||
using Base::AMmaKStride;
|
||||
using Base::BMmaKStride;
|
||||
using Base::MWaves;
|
||||
|
||||
static constexpr index_t PrefetchStages = 2;
|
||||
static constexpr index_t PrefillStages = 1;
|
||||
@@ -184,12 +185,19 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
|
||||
{
|
||||
constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num;
|
||||
constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
|
||||
constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num;
|
||||
|
||||
constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num * MWaves;
|
||||
constexpr auto mfma_interleave = MPerXDL == 32 ? 1 : 2;
|
||||
// B global
|
||||
static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
if constexpr(MPerBlock >= 128 && NPerBlock >= 128)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 2 * mfma_interleave, 0);
|
||||
}
|
||||
else
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, mfma_interleave, 0);
|
||||
}
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
});
|
||||
|
||||
@@ -203,10 +211,10 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
|
||||
});
|
||||
|
||||
// A local
|
||||
static_for<0, num_ds_read_inst_a / 2, 1>{}([&](auto i) {
|
||||
static_for<0, num_ds_read_inst_a / 2 * mfma_interleave, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 2 / mfma_interleave, 0); // DS read
|
||||
});
|
||||
}
|
||||
|
||||
@@ -320,7 +328,6 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
|
||||
[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType,
|
||||
xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
@@ -0,0 +1,199 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_description/cluster_descriptor.hpp"
|
||||
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_gather.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
/**
|
||||
* @brief Blockwise data transfer
|
||||
*
|
||||
* This version does following things to avoid scratch memory issue
|
||||
* 1. Use StaticallyIndexedArray instead of C array for thread buffer
|
||||
* 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
|
||||
* 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
|
||||
*
|
||||
*/
|
||||
template <typename ThreadGroup,
|
||||
typename SrcElementwiseOperation,
|
||||
typename DstElementwiseOperation,
|
||||
InMemoryDataOperationEnum DstInMemOp,
|
||||
typename BlockSliceLengths,
|
||||
typename ThreadClusterLengths,
|
||||
typename ThreadClusterArrangeOrder,
|
||||
typename SrcData,
|
||||
typename DstData,
|
||||
typename SrcDesc,
|
||||
typename DstDesc,
|
||||
typename SrcDimAccessOrder,
|
||||
typename DstDimAccessOrder,
|
||||
index_t SrcVectorDim,
|
||||
index_t DstVectorDim,
|
||||
index_t SrcScalarPerVector,
|
||||
index_t DstScalarPerVector,
|
||||
index_t SrcScalarStrideInVector,
|
||||
index_t DstScalarStrideInVector,
|
||||
bool ThreadTransferSrcResetCoordinateAfterRun,
|
||||
bool ThreadTransferDstResetCoordinateAfterRun,
|
||||
index_t GatherDim = 1,
|
||||
index_t NumThreadScratch = 1>
|
||||
struct ThreadGroupTensorSliceTransfer_v4r1_gather
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();
|
||||
static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{};
|
||||
static constexpr index_t gather_num = thread_slice_lengths.At(Number<GatherDim>{});
|
||||
using Index = MultiIndex<nDim>;
|
||||
|
||||
__device__ constexpr ThreadGroupTensorSliceTransfer_v4r1_gather(
|
||||
const SrcDesc& src_desc,
|
||||
const Index& src_block_slice_origin,
|
||||
const SrcElementwiseOperation& src_element_op,
|
||||
const DstDesc& dst_desc,
|
||||
const Index& dst_block_slice_origin,
|
||||
const DstElementwiseOperation& dst_element_op,
|
||||
const StaticallyIndexedArray<index_t, gather_num>& gather_offsets)
|
||||
: threadwise_transfer_(src_desc,
|
||||
make_zero_multi_index<nDim>(),
|
||||
src_element_op,
|
||||
dst_desc,
|
||||
make_zero_multi_index<nDim>(),
|
||||
dst_element_op,
|
||||
gather_offsets)
|
||||
|
||||
{
|
||||
static_assert(nDim == remove_cvref_t<SrcDesc>::GetNumOfDimension() &&
|
||||
nDim == remove_cvref_t<DstDesc>::GetNumOfDimension() &&
|
||||
nDim == ThreadClusterLengths::Size() &&
|
||||
nDim == ThreadClusterArrangeOrder::Size() &&
|
||||
nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(),
|
||||
"wrong! nDim not consistent");
|
||||
|
||||
static_assert(
|
||||
is_same<BlockSliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
|
||||
"wrong! threads should be mapped to cover entire slicing window");
|
||||
|
||||
static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
|
||||
"wrong! ThreadGroup::GetNumOfThread() too small");
|
||||
|
||||
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
|
||||
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
|
||||
make_multi_index(ThreadGroup::GetThreadId()));
|
||||
threadwise_transfer_.SetSrcSliceOrigin(
|
||||
src_desc, src_block_slice_origin + thread_cluster_idx * thread_slice_lengths);
|
||||
threadwise_transfer_.SetDstSliceOrigin(
|
||||
dst_desc, dst_block_slice_origin + thread_cluster_idx * thread_slice_lengths);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_block_slice_origin)
|
||||
{
|
||||
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
|
||||
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
|
||||
make_multi_index(ThreadGroup::GetThreadId()));
|
||||
|
||||
const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
|
||||
threadwise_transfer_.SetSrcSliceOrigin(src_desc,
|
||||
src_block_slice_origin + thread_data_idx_begin);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SeqIdx, index_t ThreadScratchId = 0>
|
||||
__device__ constexpr auto GetSrcThreadScratchIdx()
|
||||
{
|
||||
return threadwise_transfer_.template GetSrcThreadScratchIdx<SeqIdx, ThreadScratchId>();
|
||||
}
|
||||
|
||||
template <typename SrcBuffer, index_t ThreadScratchId = 0>
|
||||
__device__ void RunRead(const SrcDesc& src_desc,
|
||||
const SrcBuffer& src_buf,
|
||||
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
|
||||
{
|
||||
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
|
||||
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
threadwise_transfer_.RunRead(src_desc, src_buf, thread_scratch_id);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename DstBuffer, index_t ThreadScratchId = 0>
|
||||
__device__ void RunWrite(const DstDesc& dst_desc,
|
||||
DstBuffer& dst_buf,
|
||||
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
|
||||
{
|
||||
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
|
||||
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
threadwise_transfer_.RunWrite(dst_desc, dst_buf, thread_scratch_id);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SrcBuffer, typename DstBuffer, index_t ThreadScratchId>
|
||||
__device__ void Run(const SrcDesc& src_desc,
|
||||
const SrcBuffer& src_buf,
|
||||
const DstDesc& dst_desc,
|
||||
DstBuffer& dst_buf,
|
||||
Number<ThreadScratchId> thread_scratch_id)
|
||||
{
|
||||
RunRead(src_desc, src_buf, thread_scratch_id);
|
||||
RunWrite(dst_desc, dst_buf, thread_scratch_id);
|
||||
}
|
||||
|
||||
__device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step)
|
||||
{
|
||||
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
|
||||
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
threadwise_transfer_.MoveSrcSliceWindow(src_desc, step);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
|
||||
{
|
||||
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
|
||||
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
static constexpr auto thread_cluster_desc_ =
|
||||
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
|
||||
|
||||
using ThreadwiseTransfer =
|
||||
ThreadwiseTensorSliceTransfer_v3r1_gather<decltype(thread_slice_lengths),
|
||||
SrcElementwiseOperation,
|
||||
DstElementwiseOperation,
|
||||
DstInMemOp,
|
||||
SrcData,
|
||||
DstData,
|
||||
SrcDesc,
|
||||
DstDesc,
|
||||
SrcDimAccessOrder,
|
||||
DstDimAccessOrder,
|
||||
SrcVectorDim,
|
||||
DstVectorDim,
|
||||
SrcScalarPerVector,
|
||||
DstScalarPerVector,
|
||||
SrcScalarStrideInVector,
|
||||
DstScalarStrideInVector,
|
||||
ThreadTransferSrcResetCoordinateAfterRun,
|
||||
ThreadTransferDstResetCoordinateAfterRun,
|
||||
GatherDim,
|
||||
NumThreadScratch>;
|
||||
|
||||
ThreadwiseTransfer threadwise_transfer_;
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,241 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_description/cluster_descriptor.hpp"
|
||||
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp"
|
||||
#include "ck/utility/is_detected.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// Thread-group level multi-source, multi-destination tensor slice data movement
|
||||
// Assume:
|
||||
// 1. All sources and destinations are DynamicBuffer
|
||||
// 2. Same VectorDim and ScalerPerVector for all sources and destinations
|
||||
// 3. DstInMemOps are per destination tensor
|
||||
// 4. ThreadTransferSrcResetCoordinateAfterRunFlags are per source tensor
|
||||
// 5. ThreadTransferDstResetCoordinateAfterRunFlags are per destination tensor
|
||||
//
|
||||
// Does following things to avoid scratch memory issue
|
||||
// 1. Pass tensor descritpors by reference (or tuple of references)
|
||||
// 2. Does not keep reference to tensor descriptor
|
||||
// 3. Does not construct new tensor coordinate when call Run()
|
||||
template <typename ThreadGroup,
|
||||
typename SrcDatas,
|
||||
typename DstDatas,
|
||||
typename SrcDescs,
|
||||
typename DstDescs,
|
||||
typename ElementwiseOperation,
|
||||
typename DstInMemOps, // Sequence<InMemoryDataOperationEnum ...>
|
||||
typename SliceLengths,
|
||||
typename ThreadClusterLengths,
|
||||
typename ThreadClusterArrangeOrder,
|
||||
typename SrcDimAccessOrder,
|
||||
typename DstDimAccessOrder,
|
||||
index_t SrcVectorDim,
|
||||
index_t DstVectorDim,
|
||||
typename SrcScalarPerVectors,
|
||||
index_t DstScalarPerVector,
|
||||
typename ThreadTransferSrcResetCoordinateAfterRunFlags,
|
||||
typename ThreadTransferDstResetCoordinateAfterRunFlags,
|
||||
index_t ScatterDim = 1,
|
||||
bool OutputScatter = true,
|
||||
index_t ScatterWeightIdx = 3,
|
||||
index_t NumThreadScratch = 1>
|
||||
struct ThreadGroupTensorSliceTransfer_v7r3_scatter
|
||||
{
|
||||
static constexpr index_t nDim =
|
||||
remove_cvref_t<tuple_element_t<0, SrcDescs>>::GetNumOfDimension();
|
||||
|
||||
static constexpr index_t mod_num =
|
||||
ThreadClusterLengths{}.At(Number<3>{}); // Dirty HACK FELIX, TODO fix
|
||||
static constexpr index_t nSrc = remove_cvref_t<SrcDescs>::Size();
|
||||
static constexpr index_t nDst = remove_cvref_t<DstDescs>::Size();
|
||||
|
||||
using Index = MultiIndex<nDim>;
|
||||
|
||||
static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{};
|
||||
static constexpr index_t scatter_num = thread_slice_lengths.At(Number<ScatterDim>{});
|
||||
|
||||
__device__ constexpr ThreadGroupTensorSliceTransfer_v7r3_scatter(
|
||||
const SrcDescs& src_descs,
|
||||
const StaticallyIndexedArray<Index, nSrc>& src_block_slice_origins,
|
||||
const DstDescs& dst_descs,
|
||||
const StaticallyIndexedArray<Index, nDst>& dst_block_slice_origins,
|
||||
const ElementwiseOperation& element_op)
|
||||
: threadwise_transfer_(src_descs,
|
||||
StaticallyIndexedArray<Index, nSrc>{},
|
||||
dst_descs,
|
||||
StaticallyIndexedArray<Index, nDst>{},
|
||||
element_op)
|
||||
{
|
||||
static_assert(nSrc == SrcDatas::Size() && nSrc == SrcDescs::Size() &&
|
||||
nSrc == ThreadTransferSrcResetCoordinateAfterRunFlags::Size() &&
|
||||
nDst == DstDatas::Size() && nDst == DstDescs::Size() &&
|
||||
nDst == ThreadTransferDstResetCoordinateAfterRunFlags::Size(),
|
||||
"wrong!");
|
||||
|
||||
static_for<0, nSrc, 1>{}([&](auto i) {
|
||||
static_assert(
|
||||
nDim == remove_cvref_t<tuple_element_t<i.value, SrcDescs>>::GetNumOfDimension(),
|
||||
"wrong!");
|
||||
});
|
||||
|
||||
static_for<0, nDst, 1>{}([&](auto i) {
|
||||
static_assert(
|
||||
nDim == remove_cvref_t<tuple_element_t<i.value, DstDescs>>::GetNumOfDimension(),
|
||||
"wrong!");
|
||||
});
|
||||
|
||||
static_assert(nDim == ThreadClusterLengths::Size() &&
|
||||
nDim == ThreadClusterArrangeOrder::Size() &&
|
||||
nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(),
|
||||
"wrong! nDim not consistent");
|
||||
|
||||
static_assert(
|
||||
is_same<SliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
|
||||
"wrong! threads should be mapped to cover entire slicing window");
|
||||
|
||||
static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
|
||||
"wrong! ThreadGroup::GetNumOfThread() too small");
|
||||
|
||||
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
|
||||
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
const auto src_thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
|
||||
make_multi_index(ThreadGroup::GetThreadId()));
|
||||
const auto src_thread_slice_origins = generate_tuple(
|
||||
[&](auto i) {
|
||||
return src_block_slice_origins[i] +
|
||||
src_thread_cluster_idx * thread_slice_lengths;
|
||||
},
|
||||
Number<nSrc>{});
|
||||
|
||||
const auto dst_thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
|
||||
make_multi_index(OutputScatter ? ThreadGroup::GetThreadId() % mod_num
|
||||
: ThreadGroup::GetThreadId()));
|
||||
const auto dst_thread_slice_origins = generate_tuple(
|
||||
[&](auto i) {
|
||||
return dst_block_slice_origins[i] +
|
||||
dst_thread_cluster_idx * thread_slice_lengths;
|
||||
},
|
||||
Number<nDst>{});
|
||||
|
||||
threadwise_transfer_.SetSrcSliceOrigins(src_descs, src_thread_slice_origins);
|
||||
threadwise_transfer_.SetDstSliceOrigins(dst_descs, dst_thread_slice_origins);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SrcBuffers, index_t ThreadScratchId = 0>
|
||||
__device__ void RunRead(const SrcDescs& src_descs,
|
||||
const SrcBuffers& src_bufs,
|
||||
StaticallyIndexedArray<float, scatter_num>& scatter_weights,
|
||||
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
|
||||
{
|
||||
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
|
||||
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
threadwise_transfer_.RunRead(src_descs, src_bufs, scatter_weights, thread_scratch_id);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
using is_tuple = decltype(std::declval<T&>().IsTuple());
|
||||
|
||||
template <typename DstBuffers, index_t ThreadScratchId = 0>
|
||||
__device__ void RunWrite(const DstDescs& dst_descs,
|
||||
DstBuffers dst_bufs,
|
||||
StaticallyIndexedArray<index_t, scatter_num>& scatter_offsets,
|
||||
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
|
||||
{
|
||||
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
|
||||
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
if constexpr(is_detected<is_tuple, decltype(dst_bufs)>::value)
|
||||
threadwise_transfer_.RunWrite(
|
||||
dst_descs, dst_bufs, scatter_offsets, thread_scratch_id);
|
||||
else
|
||||
threadwise_transfer_.RunWrite(
|
||||
dst_descs, tie(dst_bufs), scatter_offsets, thread_scratch_id);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SrcBuffers, typename DstBuffers>
|
||||
__device__ void Run(const SrcDescs& src_descs,
|
||||
const SrcBuffers& src_bufs,
|
||||
const DstDescs& dst_descs,
|
||||
DstBuffers dst_bufs,
|
||||
StaticallyIndexedArray<index_t, scatter_num>& scatter_offsets,
|
||||
StaticallyIndexedArray<float, scatter_num>& scatter_weights)
|
||||
{
|
||||
RunRead(src_descs, src_bufs, scatter_weights);
|
||||
RunWrite(dst_descs, dst_bufs, scatter_offsets);
|
||||
}
|
||||
|
||||
template <index_t ISrc>
|
||||
__device__ void
|
||||
MoveSrcSliceWindow(const SrcDescs& src_descs, Number<ISrc> iSrc, const Index& step)
|
||||
{
|
||||
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
|
||||
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
threadwise_transfer_.MoveSrcSliceWindow(src_descs, iSrc, step);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ void MoveSrcSliceWindow(const SrcDescs& src_descs, const Index& step)
|
||||
{
|
||||
static_for<0, SrcDescs::Size(), 1>{}(
|
||||
[&](auto i) { MoveSrcSliceWindow(src_descs, i, step); });
|
||||
}
|
||||
|
||||
template <index_t IDst>
|
||||
__device__ void
|
||||
MoveDstSliceWindow(const DstDescs& dst_descs, Number<IDst> iDst, const Index& step)
|
||||
{
|
||||
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
|
||||
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
threadwise_transfer_.MoveDstSliceWindow(dst_descs, iDst, step);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ void MoveDstSliceWindow(const DstDescs& dst_descs, const Index& step)
|
||||
{
|
||||
static_for<0, DstDescs::Size(), 1>{}(
|
||||
[&](auto i) { MoveDstSliceWindow(dst_descs, i, step); });
|
||||
}
|
||||
|
||||
private:
|
||||
static constexpr auto thread_cluster_desc_ =
|
||||
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
|
||||
|
||||
using ThreadwiseTransfer =
|
||||
ThreadwiseTensorSliceTransfer_v7r3_scatter<SrcDatas,
|
||||
DstDatas,
|
||||
SrcDescs,
|
||||
DstDescs,
|
||||
ElementwiseOperation,
|
||||
DstInMemOps,
|
||||
decltype(thread_slice_lengths),
|
||||
SrcDimAccessOrder,
|
||||
DstDimAccessOrder,
|
||||
SrcVectorDim,
|
||||
DstVectorDim,
|
||||
SrcScalarPerVectors,
|
||||
DstScalarPerVector,
|
||||
ThreadTransferSrcResetCoordinateAfterRunFlags,
|
||||
ThreadTransferDstResetCoordinateAfterRunFlags,
|
||||
ScatterDim,
|
||||
OutputScatter,
|
||||
ScatterWeightIdx,
|
||||
NumThreadScratch>;
|
||||
|
||||
ThreadwiseTransfer threadwise_transfer_;
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
509
include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp
Normal file
509
include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp
Normal file
@@ -0,0 +1,509 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/host_utility/flush_cache.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename CDataType,
|
||||
typename GemmAccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
GemmSpecialization GemmSpec,
|
||||
index_t BlockSize,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t AK1,
|
||||
index_t BK1,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t MXdlPerWave,
|
||||
index_t NXdlPerWave,
|
||||
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
index_t ABlockTransferSrcVectorDim,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t ABlockTransferDstScalarPerVector_AK1,
|
||||
bool ABlockLdsExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
index_t BBlockTransferSrcVectorDim,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferDstScalarPerVector_BK1,
|
||||
bool BBlockLdsExtraN,
|
||||
index_t CShuffleMXdlPerWavePerShuffle,
|
||||
index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename CDEShuffleBlockTransferScalarPerVectors,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
|
||||
bool NSwizzle = false,
|
||||
bool IsInputGemm = true,
|
||||
typename ComputeTypeA = CDataType,
|
||||
typename ComputeTypeB = ComputeTypeA,
|
||||
typename LDSTypeA = ComputeTypeA,
|
||||
typename LDSTypeB = ComputeTypeB>
|
||||
struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle<ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
CDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation>
|
||||
{
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
using GridwiseGemm =
|
||||
GridwiseMoeGemm<ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
GemmAccDataType,
|
||||
CShuffleDataType,
|
||||
DsDataType,
|
||||
CDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
GemmSpec,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
AK1,
|
||||
BK1,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_AK1,
|
||||
false,
|
||||
ABlockLdsExtraM,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_BK1,
|
||||
false,
|
||||
BBlockLdsExtraN,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CDEShuffleBlockTransferScalarPerVectors,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer,
|
||||
NSwizzle,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB,
|
||||
LDSTypeA,
|
||||
LDSTypeB>;
|
||||
|
||||
using Argument = typename GridwiseGemm::Argument;
|
||||
|
||||
int GetPreShuffleParameters() override { return NPerXDL; }
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(stream_config.log_level_ > 0)
|
||||
{
|
||||
arg.Print();
|
||||
}
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(arg))
|
||||
{
|
||||
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
|
||||
}
|
||||
|
||||
index_t gdx, gdy, gdz;
|
||||
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N);
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
index_t k_grain = arg.KBatch * KPerBlock;
|
||||
index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
|
||||
|
||||
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
|
||||
|
||||
const auto RunKernel = [&](const auto& kernel) {
|
||||
if(stream_config.flush_cache)
|
||||
{
|
||||
|
||||
std::array<std::size_t, NumDTensor> DsSize;
|
||||
|
||||
Argument arg_ = arg;
|
||||
|
||||
const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
|
||||
arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
|
||||
const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
|
||||
arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);
|
||||
|
||||
auto size_a_buffer =
|
||||
a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType);
|
||||
auto size_b_buffer =
|
||||
b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType);
|
||||
|
||||
const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N(
|
||||
arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs);
|
||||
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
|
||||
DsSize[i] = ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType);
|
||||
});
|
||||
ck::utility::RotatingMemWrapperMultiD<Argument, DsDataType> rotating_mem(
|
||||
arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer, DsSize);
|
||||
rotating_mem.Print();
|
||||
|
||||
auto run_flush_cache = [&]() {
|
||||
// flush icache
|
||||
ck::utility::flush_icache();
|
||||
// rotating mem
|
||||
rotating_mem.Next();
|
||||
// clear c mem
|
||||
if(arg_.KBatch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
|
||||
0,
|
||||
arg_.M * arg_.N * sizeof(CDataType),
|
||||
stream_config.stream_id_));
|
||||
};
|
||||
|
||||
ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
|
||||
stream_config,
|
||||
run_flush_cache,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg_);
|
||||
}
|
||||
else
|
||||
{
|
||||
if(arg.KBatch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
|
||||
0,
|
||||
arg.M * arg.N * sizeof(CDataType),
|
||||
stream_config.stream_id_));
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
|
||||
}
|
||||
};
|
||||
|
||||
constexpr auto estimated_reg_a = MPerBlock * KPerBlock * sizeof(ADataType) / BlockSize /
|
||||
4 * (1 + GridwiseGemm::NWave);
|
||||
constexpr auto estimated_reg_b =
|
||||
NPerBlock * KPerBlock * sizeof(BDataType) / BlockSize / 4 * (2);
|
||||
constexpr auto estimated_reg_c =
|
||||
MPerBlock * NPerBlock * sizeof(GemmAccDataType) / BlockSize / 4;
|
||||
constexpr auto estimated_reg_total =
|
||||
estimated_reg_a + estimated_reg_b + estimated_reg_c;
|
||||
|
||||
constexpr index_t minimum_occupancy = (estimated_reg_total >= 256) ? 1 : 2;
|
||||
|
||||
constexpr auto MemoryDataOp =
|
||||
IsInputGemm ? InMemoryDataOperationEnum::Set : InMemoryDataOperationEnum::AtomicAdd;
|
||||
if(has_main_k_block_loop)
|
||||
{
|
||||
// Tail number always full
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
{
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
const auto kernel = kernel_moe_gemm<GridwiseGemm,
|
||||
true,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
IsInputGemm,
|
||||
TailNumber::Odd>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_moe_gemm<GridwiseGemm,
|
||||
true,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
IsInputGemm,
|
||||
TailNumber::Even>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
}
|
||||
}
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2 ||
|
||||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
const auto kernel = kernel_moe_gemm_2lds<GridwiseGemm,
|
||||
true,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
IsInputGemm,
|
||||
TailNumber::Odd>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_moe_gemm_2lds<GridwiseGemm,
|
||||
true,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
IsInputGemm,
|
||||
TailNumber::Even>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("todo: only v1 & v2 support now");
|
||||
}
|
||||
}
|
||||
#if 1
|
||||
else
|
||||
{
|
||||
// Tail number always 1
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
{
|
||||
const auto kernel = kernel_moe_gemm<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
IsInputGemm,
|
||||
TailNumber::Odd>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr bool IsValidCompilationParameter()
|
||||
{
|
||||
// TODO: properly implement this check
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
// only impl kbatch 1 now
|
||||
if(arg.KBatch > 1)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if(!ck::is_xdl_supported())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if(!is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t> && arg.KBatch > 1)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
|
||||
GemmSpec == GemmSpecialization::NKPadding ||
|
||||
GemmSpec == GemmSpecialization::MNKPadding ||
|
||||
GemmSpec == GemmSpecialization::KPadding))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if(arg.N % NPerBlock != 0 || arg.K % KPerBlock != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return GridwiseGemm::CheckValidity(arg);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
static auto MakeArgument(const void* p_sorted_token_ids,
|
||||
const void* p_sorted_expert_ids,
|
||||
const void* p_max_token_id,
|
||||
const void* p_a,
|
||||
const void* p_b,
|
||||
std::array<const void*, NumDTensor> p_ds,
|
||||
void* p_c,
|
||||
index_t NumTokens,
|
||||
index_t TopK,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
std::array<index_t, NumDTensor> StrideDs,
|
||||
index_t StrideC,
|
||||
index_t KBatch,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op)
|
||||
{
|
||||
return Argument{static_cast<const index_t*>(p_sorted_token_ids),
|
||||
static_cast<const index_t*>(p_sorted_expert_ids),
|
||||
static_cast<const index_t*>(p_max_token_id),
|
||||
static_cast<const ADataType*>(p_a),
|
||||
static_cast<const BDataType*>(p_b),
|
||||
p_ds,
|
||||
static_cast<CDataType*>(p_c),
|
||||
NumTokens,
|
||||
TopK,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideDs,
|
||||
StrideC,
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
std::array<const void*, NumDTensor> p_ds,
|
||||
void* p_c,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
std::array<ck::index_t, NumDTensor> StrideDs,
|
||||
index_t StrideC,
|
||||
index_t KBatch,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op) override
|
||||
{
|
||||
return std::make_unique<Argument>(nullptr,
|
||||
nullptr,
|
||||
nullptr,
|
||||
static_cast<const ADataType*>(p_a),
|
||||
static_cast<const BDataType*>(p_b),
|
||||
p_ds,
|
||||
static_cast<CDataType*>(p_c),
|
||||
M, // randoms set, no use
|
||||
0,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideDs,
|
||||
StrideC,
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
|
||||
{BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
|
||||
{BlockGemmPipelineScheduler::Interwave, "Interwave"}};
|
||||
|
||||
std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
|
||||
{BlockGemmPipelineVersion::v1, "v1"}, {BlockGemmPipelineVersion::v2, "v2"}};
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceMoeGEmm"
|
||||
<< "<"
|
||||
<< getGemmSpecializationString(GemmSpec) << ", "
|
||||
<< std::string(ALayout::name)[0]
|
||||
<< std::string(BLayout::name)[0]
|
||||
<< std::string(CLayout::name)[0]
|
||||
<< ">"
|
||||
<< " BlkSize: "
|
||||
<< BlockSize << ", "
|
||||
<< "BlkTile: "
|
||||
<< MPerBlock<<"x"<<NPerBlock<<"x"<<KPerBlock << ", "
|
||||
<< "WaveTile: "
|
||||
<< MPerXDL<<"x"<<NPerXDL << ", "
|
||||
<< "WaveMap: "
|
||||
<< MXdlPerWave<<"x" << NXdlPerWave<<", "
|
||||
<< "VmemReadVec: "
|
||||
<< ABlockTransferSrcScalarPerVector<<"x"<<BBlockTransferSrcScalarPerVector<<", "
|
||||
<< "BlkGemmPipelineScheduler: "
|
||||
<< BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
|
||||
<< "BlkGemmPipelineVersion: "
|
||||
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
|
||||
<< "BlkGemmPipelinePrefetchStages: "
|
||||
<< GridwiseGemm::BlockwiseGemmPipe::PrefetchStages;
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
2144
include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp
Normal file
2144
include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,903 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
|
||||
#include "ck/tensor/static_tensor.hpp"
|
||||
#include "ck/utility/is_detected.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// Assume:
|
||||
// 1. src_desc and dst_desc are not known at compile-time
|
||||
// 2. SrcBuffer and DstBuffer are DynamicBuffer
|
||||
// 3. src_slice_origin and dst_slice_origin are not known at compile-time,
|
||||
// 4. Use thread buffer
|
||||
template <typename SliceLengths,
|
||||
typename SrcElementwiseOperation,
|
||||
typename DstElementwiseOperation,
|
||||
InMemoryDataOperationEnum DstInMemOp,
|
||||
typename SrcData,
|
||||
typename DstData,
|
||||
typename SrcDesc,
|
||||
typename DstDesc,
|
||||
typename SrcDimAccessOrder,
|
||||
typename DstDimAccessOrder,
|
||||
index_t SrcVectorDim,
|
||||
index_t DstVectorDim,
|
||||
index_t SrcScalarPerVector,
|
||||
index_t DstScalarPerVector,
|
||||
index_t SrcScalarStrideInVector,
|
||||
index_t DstScalarStrideInVector,
|
||||
bool SrcResetCoordinateAfterRun, // control whether to move back src coordinate after each
|
||||
// RunRead(), will be fused with MoveSrcSliceWindow to
|
||||
// save addr computation
|
||||
bool DstResetCoordinateAfterRun, // control whether to move back dst coordinate after each
|
||||
// RunWrite(), will be fused with MoveDstSliceWindow to
|
||||
// save addr computation
|
||||
index_t GatherDim = 1,
|
||||
index_t NumThreadScratch = 1>
|
||||
struct ThreadwiseTensorSliceTransfer_v3r1_gather
|
||||
{
|
||||
static constexpr index_t nDim = SliceLengths::Size();
|
||||
using Index = MultiIndex<nDim>;
|
||||
|
||||
using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
|
||||
using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
|
||||
|
||||
using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
|
||||
using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr index_t gather_num = SliceLengths{}.At(Number<GatherDim>{});
|
||||
|
||||
__device__ constexpr ThreadwiseTensorSliceTransfer_v3r1_gather(
|
||||
const SrcDesc& src_desc,
|
||||
const Index& src_slice_origin,
|
||||
const SrcElementwiseOperation& src_element_op,
|
||||
const DstDesc& dst_desc,
|
||||
const Index& dst_slice_origin,
|
||||
const DstElementwiseOperation& dst_element_op,
|
||||
const StaticallyIndexedArray<index_t, gather_num>& gather_offsets)
|
||||
: src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)),
|
||||
dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)),
|
||||
src_element_op_(src_element_op),
|
||||
dst_element_op_(dst_element_op),
|
||||
gather_offsets_(gather_offsets)
|
||||
{
|
||||
}
|
||||
|
||||
__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
|
||||
{
|
||||
|
||||
auto adjusted_origin_idx = [&]() {
|
||||
Index idx;
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
idx(i) = i.value == GatherDim ? 0 : src_slice_origin_idx[Number<i>{}];
|
||||
});
|
||||
return idx;
|
||||
}();
|
||||
src_coord_ = make_tensor_coordinate(src_desc, adjusted_origin_idx);
|
||||
}
|
||||
|
||||
__device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
|
||||
{
|
||||
dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx);
|
||||
}
|
||||
|
||||
template <typename SrcBuffer, index_t ThreadScratchId = 0>
|
||||
__device__ void RunRead(const SrcDesc& src_desc,
|
||||
const SrcBuffer& src_buf,
|
||||
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
|
||||
{
|
||||
static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Global or
|
||||
SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Lds,
|
||||
"wrong!");
|
||||
|
||||
static_assert(
|
||||
is_same<remove_cvref_t<typename SrcBuffer::type>, remove_cvref_t<SrcData>>::value,
|
||||
"wrong! SrcBuffer and SrcData data type are inconsistent");
|
||||
|
||||
// scalar per access on each dim
|
||||
// TODO: don't use lambda_scalar_per_access
|
||||
constexpr auto src_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
|
||||
static_assert(SliceLengths::At(SrcVectorDim) % SrcScalarPerVector == 0,
|
||||
"SliceLengths[SrcVectorDim] must be divisible by SrcScalarPerVector");
|
||||
|
||||
constexpr auto src_dim_access_order = SrcDimAccessOrder{};
|
||||
constexpr auto ordered_gather_dim = src_dim_access_order[GatherDim];
|
||||
constexpr auto ordered_src_access_lengths =
|
||||
container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
|
||||
|
||||
// make forward steps
|
||||
const auto src_forward_steps = generate_tuple(
|
||||
[&](auto i) {
|
||||
Index forward_step_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto j) {
|
||||
forward_step_idx(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0;
|
||||
});
|
||||
|
||||
return make_tensor_coordinate_step(src_desc, forward_step_idx);
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
// make backward steps
|
||||
const auto src_backward_steps = generate_tuple(
|
||||
[&](auto i) {
|
||||
Index backward_step_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto j) {
|
||||
backward_step_idx(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0;
|
||||
});
|
||||
|
||||
return make_tensor_coordinate_step(src_desc, backward_step_idx);
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
// loop over tensor and copy
|
||||
static_ford<decltype(ordered_src_access_lengths)>{}([&](auto ordered_src_access_idx) {
|
||||
// judge move forward or move backward
|
||||
constexpr auto forward_sweep = [&]() {
|
||||
StaticallyIndexedArray<bool, nDim> forward_sweep_;
|
||||
|
||||
forward_sweep_(I0) = true;
|
||||
|
||||
static_for<1, nDim, 1>{}([&](auto i) {
|
||||
index_t tmp = ordered_src_access_idx[I0];
|
||||
|
||||
static_for<1, i, 1>{}([&](auto j) {
|
||||
tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j];
|
||||
});
|
||||
|
||||
forward_sweep_(i) = tmp % 2 == 0;
|
||||
});
|
||||
|
||||
return forward_sweep_;
|
||||
}();
|
||||
|
||||
// calculate src data index
|
||||
constexpr auto src_data_idx = [&]() {
|
||||
Index ordered_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
ordered_idx(i) = forward_sweep[i] ? ordered_src_access_idx[i]
|
||||
: ordered_src_access_lengths[i] - 1 -
|
||||
ordered_src_access_idx[i];
|
||||
});
|
||||
|
||||
return container_reorder_given_old2new(ordered_idx, src_dim_access_order) *
|
||||
src_scalar_per_access;
|
||||
}();
|
||||
|
||||
constexpr auto src_data_idx_seq = generate_sequence_v2(
|
||||
[&](auto i) { return Number<src_data_idx[i]>{}; }, Number<src_data_idx.Size()>{});
|
||||
|
||||
auto gather_offset =
|
||||
gather_offsets_(ordered_src_access_idx[Number<ordered_gather_dim>{}]);
|
||||
|
||||
// maintain a container record is_src_valid, waiting for RunWrite use.
|
||||
const index_t ld_offset = src_coord_.GetOffset() + gather_offset;
|
||||
const bool is_src_valid =
|
||||
ld_offset <
|
||||
src_desc
|
||||
.GetElementSpaceSize(); // hack felix, todo use coord
|
||||
// coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc,
|
||||
// src_coord_) && (gather_offset < 32*512);
|
||||
src_oob_thread_scratch_tuple_(thread_scratch_id)
|
||||
.template SetAsType<bool>(src_data_idx_seq, is_src_valid);
|
||||
|
||||
using src_vector_type = vector_type_maker_t<SrcData, SrcScalarPerVector>;
|
||||
using src_vector_t = typename src_vector_type::type;
|
||||
|
||||
auto src_vector_container =
|
||||
src_vector_type{src_buf.template Get<src_vector_t>(ld_offset, true)};
|
||||
|
||||
using dst_vector_type = vector_type_maker_t<DstData, SrcScalarPerVector>;
|
||||
using dst_vector_t = typename dst_vector_type::type;
|
||||
dst_vector_type op_r_v;
|
||||
|
||||
constexpr auto get_elem_op_vec_len = []() {
|
||||
if constexpr(is_detected<is_pack8_invocable_t, decltype(src_element_op_)>::value)
|
||||
{
|
||||
if constexpr(decltype(src_element_op_)::is_pack8_invocable)
|
||||
return math::min(8, SrcScalarPerVector);
|
||||
}
|
||||
if constexpr(is_detected<is_pack4_invocable_t, decltype(src_element_op_)>::value)
|
||||
{
|
||||
if constexpr(decltype(src_element_op_)::is_pack4_invocable)
|
||||
return math::min(4, SrcScalarPerVector);
|
||||
}
|
||||
if constexpr(is_detected<is_pack2_invocable_t, decltype(src_element_op_)>::value)
|
||||
{
|
||||
if constexpr(decltype(src_element_op_)::is_pack2_invocable)
|
||||
return math::min(2, SrcScalarPerVector);
|
||||
}
|
||||
return 1;
|
||||
};
|
||||
|
||||
constexpr index_t elem_op_vec_len = get_elem_op_vec_len();
|
||||
|
||||
using src_elem_op_vec_t = typename vector_type<SrcData, elem_op_vec_len>::type;
|
||||
using dst_elem_op_vec_t = typename vector_type<DstData, elem_op_vec_len>::type;
|
||||
|
||||
static_for<0, SrcScalarPerVector / elem_op_vec_len, 1>{}([&](auto idx) {
|
||||
// apply the src elementwise op and convert to DstData under the hood if needed
|
||||
src_element_op_(op_r_v.template AsType<dst_elem_op_vec_t>()(idx),
|
||||
src_vector_container.template AsType<src_elem_op_vec_t>()[idx]);
|
||||
});
|
||||
|
||||
// copy data from src_vector_container into src_thread_scratch_
|
||||
src_thread_scratch_tuple_(thread_scratch_id)
|
||||
.template SetAsType<dst_vector_t>(src_data_idx_seq,
|
||||
op_r_v.template AsType<dst_vector_t>()[I0]);
|
||||
|
||||
auto move_on_dim = [&]() constexpr
|
||||
{
|
||||
StaticallyIndexedArray<bool, nDim> move_on_dim_;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
move_on_dim_(i) = ordered_src_access_idx[i] < ordered_src_access_lengths[i] - 1;
|
||||
|
||||
static_for<i + 1, nDim, 1>{}([&](auto j) {
|
||||
move_on_dim_(i) &=
|
||||
ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1;
|
||||
});
|
||||
move_on_dim_(i) &= i.value != ordered_gather_dim;
|
||||
});
|
||||
|
||||
return move_on_dim_;
|
||||
}
|
||||
();
|
||||
// move src coord
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
if(move_on_dim[i])
|
||||
{
|
||||
if constexpr(forward_sweep[i])
|
||||
{
|
||||
move_tensor_coordinate(
|
||||
src_desc, src_coord_, src_forward_steps[src_dim_access_order[i]]);
|
||||
}
|
||||
else
|
||||
{
|
||||
move_tensor_coordinate(
|
||||
src_desc, src_coord_, src_backward_steps[src_dim_access_order[i]]);
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
// move src coordinate back to slice origin (or not)
|
||||
if constexpr(SrcResetCoordinateAfterRun)
|
||||
{
|
||||
const auto src_reset_step =
|
||||
make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep());
|
||||
|
||||
move_tensor_coordinate(src_desc, src_coord_, src_reset_step);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SeqIdx, index_t ThreadScratchId = 0>
|
||||
__device__ constexpr auto
|
||||
GetSrcThreadScratchIdx(Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
|
||||
{
|
||||
using vector_t = typename vector_type_maker<SrcData, SrcScalarPerVector>::type::type;
|
||||
return src_thread_scratch_tuple_(thread_scratch_id).template GetAsType<vector_t>(SeqIdx{});
|
||||
}
|
||||
|
||||
template <index_t ThreadScratchId>
|
||||
__device__ void
|
||||
TransferDataFromSrcThreadScratchToDstThreadScratch(Number<ThreadScratchId> thread_scratch_id)
|
||||
{
|
||||
#if !CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE
|
||||
static_ford<SliceLengths>{}([&](auto idx) {
|
||||
dst_thread_scratch_(idx) = src_thread_scratch_tuple_[thread_scratch_id][idx];
|
||||
});
|
||||
#else
|
||||
|
||||
// OOB Check
|
||||
constexpr auto src_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
|
||||
|
||||
constexpr auto src_dim_access_order = SrcDimAccessOrder{};
|
||||
|
||||
constexpr auto ordered_src_access_lengths =
|
||||
container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
|
||||
|
||||
// loop over tensor and copy
|
||||
static_ford<decltype(ordered_src_access_lengths)>{}([&](auto ordered_src_access_idx) {
|
||||
// judge move forward or move backward
|
||||
constexpr auto forward_sweep = [&]() {
|
||||
StaticallyIndexedArray<bool, nDim> forward_sweep_;
|
||||
|
||||
forward_sweep_(I0) = true;
|
||||
|
||||
static_for<1, nDim, 1>{}([&](auto i) {
|
||||
index_t tmp = ordered_src_access_idx[I0];
|
||||
|
||||
static_for<1, i, 1>{}([&](auto j) {
|
||||
tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j];
|
||||
});
|
||||
|
||||
forward_sweep_(i) = tmp % 2 == 0;
|
||||
});
|
||||
|
||||
return forward_sweep_;
|
||||
}();
|
||||
|
||||
// calculate src data index
|
||||
constexpr auto src_data_idx = [&]() {
|
||||
Index ordered_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
ordered_idx(i) = forward_sweep[i] ? ordered_src_access_idx[i]
|
||||
: ordered_src_access_lengths[i] - 1 -
|
||||
ordered_src_access_idx[i];
|
||||
});
|
||||
|
||||
return container_reorder_given_old2new(ordered_idx, src_dim_access_order) *
|
||||
src_scalar_per_access;
|
||||
}();
|
||||
|
||||
constexpr auto src_data_idx_seq = generate_sequence_v2(
|
||||
[&](auto i) { return Number<src_data_idx[i]>{}; }, Number<src_data_idx.Size()>{});
|
||||
|
||||
using vector_t = typename vector_type_maker<DstData, SrcScalarPerVector>::type::type;
|
||||
|
||||
auto op_r = src_thread_scratch_tuple_(thread_scratch_id)
|
||||
.template GetAsType<vector_t>(src_data_idx_seq);
|
||||
|
||||
const bool is_src_valid = src_oob_thread_scratch_tuple_(thread_scratch_id)
|
||||
.template GetAsType<bool>(src_data_idx_seq);
|
||||
|
||||
auto op_r_v = is_src_valid ? op_r : vector_t(0);
|
||||
|
||||
src_thread_scratch_tuple_(thread_scratch_id)
|
||||
.template SetAsType<vector_t>(src_data_idx_seq, op_r_v);
|
||||
});
|
||||
|
||||
// sub-dword transpose between src_thread_scratch_ and dst_thread_scratch_
|
||||
// TODO make this logic more generic for more sub-dword datatype
|
||||
if constexpr(SrcVectorDim != DstVectorDim &&
|
||||
((is_same<half_t, remove_cvref_t<DstData>>::value &&
|
||||
SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0) ||
|
||||
(is_same<int8_t, remove_cvref_t<DstData>>::value &&
|
||||
SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0) ||
|
||||
(is_same<f8_t, remove_cvref_t<DstData>>::value &&
|
||||
SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0)))
|
||||
{
|
||||
// each transpose does
|
||||
// DstScalarPerVector # of src vectors in src_thread_scratch_
|
||||
// SrcScalarPerVector # of dst vectors in dst_thread_scratch_
|
||||
constexpr index_t num_src_vector = Number<DstScalarPerVector>{};
|
||||
constexpr index_t num_dst_vector = Number<SrcScalarPerVector>{};
|
||||
|
||||
// Assume SrcVectorDim is not the same as DstVectorDim, so we do transpose
|
||||
// TODO: make this logic generic for all scenario
|
||||
static_assert(SrcVectorDim != DstVectorDim, "wrong");
|
||||
|
||||
constexpr auto src_scalar_step_in_vector = generate_sequence(
|
||||
detail::lambda_scalar_step_in_vector<SrcVectorDim>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto dst_scalar_step_in_vector = generate_sequence(
|
||||
detail::lambda_scalar_step_in_vector<DstVectorDim>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access_for_src_and_dst<SrcVectorDim,
|
||||
SrcScalarPerVector,
|
||||
DstVectorDim,
|
||||
DstScalarPerVector>{},
|
||||
Number<nDim>{});
|
||||
|
||||
constexpr auto access_lengths = SliceLengths{} / scalar_per_access;
|
||||
|
||||
static_ford<decltype(access_lengths)>{}([&](auto access_idx) {
|
||||
constexpr auto data_idx = access_idx * scalar_per_access;
|
||||
|
||||
constexpr auto data_idx_seq = generate_sequence_v2(
|
||||
[&](auto i) { return Number<data_idx[i]>{}; }, Number<nDim>{});
|
||||
|
||||
using src_vector_t = vector_type_maker_t<DstData, SrcScalarPerVector>;
|
||||
using dst_vector_t = vector_type_maker_t<DstData, DstScalarPerVector>;
|
||||
|
||||
// get DstScalarPerVector # of read-only references to src vectors from
|
||||
// src_thread_scratch_
|
||||
const auto src_vector_refs = generate_tie(
|
||||
[&](auto i) -> const src_vector_t& {
|
||||
// i increment corresponds to movement in DstVectorDim
|
||||
return src_thread_scratch_tuple_[thread_scratch_id].GetVectorTypeReference(
|
||||
data_idx_seq + i * dst_scalar_step_in_vector);
|
||||
},
|
||||
Number<num_src_vector>{});
|
||||
|
||||
// get SrcScalarPerVector # of references to dst vectors from dst_thread_scratch_
|
||||
auto dst_vector_refs = generate_tie(
|
||||
[&](auto i) -> dst_vector_t& {
|
||||
// i increment corresponds to movement in SrcVectorDim
|
||||
return dst_thread_scratch_.GetVectorTypeReference(
|
||||
data_idx_seq + i * src_scalar_step_in_vector);
|
||||
},
|
||||
Number<num_dst_vector>{});
|
||||
|
||||
// do data transpose
|
||||
transpose_vectors<DstData, DstScalarPerVector, SrcScalarPerVector>{}(
|
||||
src_vector_refs, dst_vector_refs);
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
static_ford<SliceLengths>{}([&](auto idx) {
|
||||
dst_thread_scratch_(idx) = src_thread_scratch_tuple_[thread_scratch_id][idx];
|
||||
});
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename DstBuffer, index_t ThreadScratchId = 0>
|
||||
__device__ void RunWrite(const DstDesc& dst_desc,
|
||||
DstBuffer& dst_buf,
|
||||
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
|
||||
{
|
||||
// if there is transpose, it's done here
|
||||
// if there is oob check, it's done here
|
||||
// TODO move this elsewhere
|
||||
TransferDataFromSrcThreadScratchToDstThreadScratch(thread_scratch_id);
|
||||
|
||||
static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum::Global or
|
||||
DstBuffer::GetAddressSpace() == AddressSpaceEnum::Lds,
|
||||
"wrong!");
|
||||
|
||||
static_assert(
|
||||
is_same<remove_cvref_t<typename DstBuffer::type>, remove_cvref_t<DstData>>::value,
|
||||
"wrong! SrcBuffer or DstBuffer data type is wrong");
|
||||
|
||||
// src scalar per access on each dim
|
||||
// TODO: don't use this
|
||||
constexpr auto dst_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
|
||||
|
||||
constexpr auto dst_dim_access_order = DstDimAccessOrder{};
|
||||
|
||||
constexpr auto ordered_dst_access_lengths =
|
||||
container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);
|
||||
|
||||
// make forward steps
|
||||
const auto dst_forward_steps = generate_tuple(
|
||||
[&](auto i) {
|
||||
Index forward_step_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto j) {
|
||||
forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0;
|
||||
});
|
||||
|
||||
return make_tensor_coordinate_step(dst_desc, forward_step_idx);
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
// make backward steps
|
||||
const auto dst_backward_steps = generate_tuple(
|
||||
[&](auto i) {
|
||||
Index backward_step_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto j) {
|
||||
backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0;
|
||||
});
|
||||
|
||||
return make_tensor_coordinate_step(dst_desc, backward_step_idx);
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
// loop over tensor and copy
|
||||
static_ford<decltype(ordered_dst_access_lengths)>{}([&](auto ordered_dst_access_idx) {
|
||||
// judge move forward or move backward
|
||||
constexpr auto forward_sweep = [&]() {
|
||||
StaticallyIndexedArray<bool, nDim> forward_sweep_;
|
||||
|
||||
forward_sweep_(I0) = true;
|
||||
|
||||
static_for<1, nDim, 1>{}([&](auto i) {
|
||||
index_t tmp = ordered_dst_access_idx[I0];
|
||||
|
||||
static_for<1, i, 1>{}([&](auto j) {
|
||||
tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j];
|
||||
});
|
||||
|
||||
forward_sweep_(i) = tmp % 2 == 0;
|
||||
});
|
||||
|
||||
return forward_sweep_;
|
||||
}();
|
||||
|
||||
// calculate dst data index
|
||||
constexpr auto dst_data_idx = [&]() {
|
||||
Index ordered_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_idx[i]
|
||||
: ordered_dst_access_lengths[i] - 1 -
|
||||
ordered_dst_access_idx[i];
|
||||
});
|
||||
|
||||
return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) *
|
||||
dst_scalar_per_access;
|
||||
}();
|
||||
|
||||
constexpr auto dst_data_idx_seq = generate_sequence_v2(
|
||||
[&](auto i) { return Number<dst_data_idx[i]>{}; }, Number<dst_data_idx.Size()>{});
|
||||
|
||||
const bool is_dst_valid =
|
||||
coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);
|
||||
|
||||
using dst_vector_type = vector_type_maker_t<DstData, DstScalarPerVector>;
|
||||
using dst_vector_t = typename dst_vector_type::type;
|
||||
|
||||
// copy data from dst_thread_scratch_ into dst_vector_container
|
||||
auto dst_vector_container = dst_vector_type{
|
||||
dst_thread_scratch_.template GetAsType<dst_vector_t>(dst_data_idx_seq)};
|
||||
|
||||
static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
|
||||
DstData dst_v;
|
||||
|
||||
// apply DstElementwiseOperation
|
||||
dst_element_op_(dst_v, dst_vector_container.template AsType<DstData>()[i]);
|
||||
|
||||
dst_vector_container.template AsType<DstData>()(i) = dst_v;
|
||||
});
|
||||
|
||||
// copy data from dst_vector_container to dst_buf
|
||||
dst_buf.template Set<dst_vector_t>(
|
||||
dst_coord_.GetOffset(),
|
||||
is_dst_valid,
|
||||
dst_vector_container.template AsType<dst_vector_t>()[I0]);
|
||||
|
||||
constexpr auto move_on_dim = [&]() constexpr
|
||||
{
|
||||
StaticallyIndexedArray<bool, nDim> move_on_dim_;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
move_on_dim_(i) = ordered_dst_access_idx[i] < ordered_dst_access_lengths[i] - 1;
|
||||
|
||||
static_for<i + 1, nDim, 1>{}([&](auto j) {
|
||||
move_on_dim_(i) &=
|
||||
ordered_dst_access_idx[j] == ordered_dst_access_lengths[j] - 1;
|
||||
});
|
||||
});
|
||||
|
||||
return move_on_dim_;
|
||||
}
|
||||
();
|
||||
|
||||
// move dst coord
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
if constexpr(move_on_dim[i])
|
||||
{
|
||||
if constexpr(forward_sweep[i])
|
||||
{
|
||||
move_tensor_coordinate(
|
||||
dst_desc, dst_coord_, dst_forward_steps[dst_dim_access_order[i]]);
|
||||
}
|
||||
else
|
||||
{
|
||||
move_tensor_coordinate(
|
||||
dst_desc, dst_coord_, dst_backward_steps[dst_dim_access_order[i]]);
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
// move dst coordinate back to slice origin (or not)
|
||||
if constexpr(DstResetCoordinateAfterRun)
|
||||
{
|
||||
const auto dst_reset_step =
|
||||
make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep());
|
||||
|
||||
move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ static constexpr auto GetSrcCoordinateResetStep()
|
||||
{
|
||||
// scalar per access on each dim
|
||||
// TODO: don't use lambda_scalar_per_access
|
||||
constexpr auto src_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
|
||||
|
||||
constexpr auto src_dim_access_order = SrcDimAccessOrder{};
|
||||
|
||||
constexpr auto ordered_src_access_lengths =
|
||||
container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
|
||||
|
||||
// judge move forward or move backward during the last iteration
|
||||
constexpr auto forward_sweep = [&]() {
|
||||
StaticallyIndexedArray<bool, nDim> forward_sweep_;
|
||||
|
||||
forward_sweep_(I0) = true;
|
||||
|
||||
static_for<1, nDim, 1>{}([&](auto i) {
|
||||
index_t tmp = ordered_src_access_lengths[I0] - 1;
|
||||
|
||||
static_for<1, i, 1>{}([&](auto j) {
|
||||
tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1;
|
||||
});
|
||||
|
||||
forward_sweep_(i) = tmp % 2 == 0;
|
||||
});
|
||||
|
||||
return forward_sweep_;
|
||||
}();
|
||||
|
||||
// calculate src data index after last iteration in RunRead(), if it has not being reset by
|
||||
// RunRead()
|
||||
constexpr auto src_data_idx = [&]() {
|
||||
Index ordered_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
ordered_idx(i) = forward_sweep[i] ? ordered_src_access_lengths[i] - 1 : 0;
|
||||
});
|
||||
|
||||
return container_reorder_given_old2new(ordered_idx, src_dim_access_order) *
|
||||
src_scalar_per_access;
|
||||
}();
|
||||
|
||||
//
|
||||
constexpr auto reset_src_data_step = [&]() {
|
||||
Index reset_src_data_step_;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
reset_src_data_step_(i) = i.value == GatherDim ? 0 : -src_data_idx[i];
|
||||
});
|
||||
|
||||
return reset_src_data_step_;
|
||||
}();
|
||||
return reset_src_data_step;
|
||||
}
|
||||
|
||||
__device__ static constexpr auto GetDstCoordinateResetStep()
|
||||
{
|
||||
// scalar per access on each dim
|
||||
// TODO: don't use lambda_scalar_per_access
|
||||
constexpr auto dst_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
|
||||
|
||||
constexpr auto dst_dim_access_order = DstDimAccessOrder{};
|
||||
|
||||
constexpr auto ordered_dst_access_lengths =
|
||||
container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);
|
||||
|
||||
// judge move forward or move backward during the last iteration
|
||||
constexpr auto forward_sweep = [&]() {
|
||||
StaticallyIndexedArray<bool, nDim> forward_sweep_;
|
||||
|
||||
forward_sweep_(I0) = true;
|
||||
|
||||
static_for<1, nDim, 1>{}([&](auto i) {
|
||||
index_t tmp = ordered_dst_access_lengths[I0] - 1;
|
||||
|
||||
static_for<1, i, 1>{}([&](auto j) {
|
||||
tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1;
|
||||
});
|
||||
|
||||
forward_sweep_(i) = tmp % 2 == 0;
|
||||
});
|
||||
|
||||
return forward_sweep_;
|
||||
}();
|
||||
|
||||
// calculate dst data index after last iteration in RunWrite(), if it has not being reset by
|
||||
// RunWrite()
|
||||
constexpr auto dst_data_idx = [&]() {
|
||||
Index ordered_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_lengths[i] - 1 : 0;
|
||||
});
|
||||
|
||||
return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) *
|
||||
dst_scalar_per_access;
|
||||
}();
|
||||
|
||||
//
|
||||
constexpr auto reset_dst_data_step = [&]() {
|
||||
Index reset_dst_data_step_;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step_(i) = -dst_data_idx[i]; });
|
||||
|
||||
return reset_dst_data_step_;
|
||||
}();
|
||||
|
||||
return reset_dst_data_step;
|
||||
}
|
||||
|
||||
// src_slice_origin_step_idx need to be known at compile-time, for performance reason
|
||||
__device__ void MoveSrcSliceWindow(const SrcDesc& src_desc,
|
||||
const Index& src_slice_origin_step_idx)
|
||||
{
|
||||
// if src coord was not reset by RunRead(), then need to adjust the step here
|
||||
const auto adjusted_step_idx =
|
||||
SrcResetCoordinateAfterRun ? src_slice_origin_step_idx
|
||||
: src_slice_origin_step_idx + GetSrcCoordinateResetStep();
|
||||
// is it OK to construct a new step every time?
|
||||
const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx);
|
||||
|
||||
move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
|
||||
}
|
||||
|
||||
// dst_slice_origin_step_idx need to be known at compile-time, for performance reason
|
||||
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc,
|
||||
const Index& dst_slice_origin_step_idx)
|
||||
{
|
||||
// if dst coord was not reset by RunWrite(), then need to adjust the step here
|
||||
const auto adjusted_step_idx =
|
||||
DstResetCoordinateAfterRun ? dst_slice_origin_step_idx
|
||||
: dst_slice_origin_step_idx + GetDstCoordinateResetStep();
|
||||
|
||||
// is it OK to construct a new step every time?
|
||||
const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx);
|
||||
|
||||
move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
|
||||
}
|
||||
|
||||
__device__ static constexpr auto GetSrcThreadScratchDescriptor()
|
||||
{
|
||||
constexpr auto src_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
|
||||
|
||||
constexpr auto src_access_lengths_and_vector_length = container_push_back(
|
||||
sequence_to_tuple_of_number(src_access_lengths), Number<SrcScalarPerVector>{});
|
||||
|
||||
// 1st stage of transforms
|
||||
constexpr auto desc0 =
|
||||
make_naive_tensor_descriptor_packed(src_access_lengths_and_vector_length);
|
||||
|
||||
// 2nd stage of transforms
|
||||
constexpr auto transforms = generate_tuple(
|
||||
[&](auto i) {
|
||||
if constexpr(i == SrcVectorDim)
|
||||
{
|
||||
return make_merge_transform_v3_division_mod(
|
||||
make_tuple(src_access_lengths_and_vector_length[i],
|
||||
src_access_lengths_and_vector_length[Number<nDim>{}]));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_pass_through_transform(src_access_lengths_and_vector_length[i]);
|
||||
}
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
constexpr auto low_dim_idss = generate_tuple(
|
||||
[&](auto i) {
|
||||
if constexpr(i == SrcVectorDim)
|
||||
{
|
||||
return Sequence<i.value, nDim>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return Sequence<i.value>{};
|
||||
}
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
constexpr auto up_dim_idss =
|
||||
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
|
||||
|
||||
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
|
||||
}
|
||||
|
||||
__device__ static constexpr auto GetSrcOOBThreadScratchDescriptor()
|
||||
{
|
||||
constexpr auto src_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
|
||||
|
||||
return make_naive_tensor_descriptor_packed(src_access_lengths);
|
||||
}
|
||||
|
||||
__device__ static constexpr auto GetDstThreadScratchDescriptor()
|
||||
{
|
||||
// 1st stage of transforms
|
||||
constexpr auto dst_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
|
||||
|
||||
constexpr auto dst_access_lengths_and_vector_length = container_push_back(
|
||||
sequence_to_tuple_of_number(dst_access_lengths), Number<DstScalarPerVector>{});
|
||||
|
||||
constexpr auto desc0 =
|
||||
make_naive_tensor_descriptor_packed(dst_access_lengths_and_vector_length);
|
||||
|
||||
// 2nd stage of transforms
|
||||
constexpr auto transforms = generate_tuple(
|
||||
[&](auto i) {
|
||||
if constexpr(i == DstVectorDim)
|
||||
{
|
||||
return make_merge_transform_v3_division_mod(
|
||||
make_tuple(dst_access_lengths_and_vector_length[i],
|
||||
dst_access_lengths_and_vector_length[Number<nDim>{}]));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_pass_through_transform(dst_access_lengths_and_vector_length[i]);
|
||||
}
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
constexpr auto low_dim_idss = generate_tuple(
|
||||
[&](auto i) {
|
||||
if constexpr(i == DstVectorDim)
|
||||
{
|
||||
return Sequence<i.value, nDim>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return Sequence<i.value>{};
|
||||
}
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
constexpr auto up_dim_idss =
|
||||
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
|
||||
|
||||
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
|
||||
}
|
||||
|
||||
private:
|
||||
static constexpr auto src_thread_scratch_desc_ = decltype(GetSrcThreadScratchDescriptor()){};
|
||||
static constexpr auto src_oob_thread_scratch_desc_ =
|
||||
decltype(GetSrcThreadScratchDescriptor()){};
|
||||
static constexpr auto dst_thread_scratch_desc_ = decltype(GetDstThreadScratchDescriptor()){};
|
||||
|
||||
using SrcThreadScratch =
|
||||
StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
|
||||
DstData, // apply data_convert with SrcThreadScratch
|
||||
SrcScalarPerVector,
|
||||
decltype(src_thread_scratch_desc_),
|
||||
true>;
|
||||
|
||||
using SrcOOBThreadScratch =
|
||||
StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
|
||||
bool, // apply data_convert with SrcThreadScratch
|
||||
1,
|
||||
decltype(src_oob_thread_scratch_desc_),
|
||||
true>;
|
||||
|
||||
using DstThreadScratch = StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
|
||||
DstData,
|
||||
DstScalarPerVector,
|
||||
decltype(dst_thread_scratch_desc_),
|
||||
true>;
|
||||
|
||||
StaticallyIndexedArray<SrcThreadScratch, NumThreadScratch> src_thread_scratch_tuple_;
|
||||
StaticallyIndexedArray<SrcOOBThreadScratch, NumThreadScratch> src_oob_thread_scratch_tuple_;
|
||||
|
||||
DstThreadScratch dst_thread_scratch_;
|
||||
|
||||
SrcCoord src_coord_;
|
||||
DstCoord dst_coord_;
|
||||
const SrcElementwiseOperation src_element_op_;
|
||||
const DstElementwiseOperation dst_element_op_;
|
||||
StaticallyIndexedArray<index_t, gather_num> gather_offsets_;
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,739 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_description/tensor_space_filling_curve.hpp"
|
||||
#include "ck/utility/is_detected.hpp"
|
||||
#include "ck/tensor/static_tensor.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp"
|
||||
|
||||
namespace ck {
|
||||
// Thread-level multi-source, multi-destination tensor slice data movement
|
||||
// Assume:
|
||||
// 1. All sources and destinations are DynamicBuffer
|
||||
// 2. Same VectorDim and ScalerPerVector for all sources and destinations
|
||||
// 3. DstInMemOps are per destination tensor
|
||||
// 4. ThreadTransferSrcResetCoordinateAfterRunFlags are per source tensor
|
||||
// 5. ThreadTransferDstResetCoordinateAfterRunFlags are per destination tensor
|
||||
// 6. Does not need to know src_descs and dst_descs at compile-time
|
||||
// 7. Does not need to know src_slice_origins and dst_slice_origins at compile-time,
|
||||
//
|
||||
// Does following things to avoid scratch memory issue
|
||||
// 1. Use StaticallyIndexedArray or vector_type instead of C array for thread buffer
|
||||
// 2. Pass tensor descritpors by reference (or tuple of references)
|
||||
// 3. Does not keep reference to tensor descriptor
|
||||
// 4. Does not construct new tensor coordinate when call Run()
|
||||
template <typename SrcDatas,
|
||||
typename DstDatas,
|
||||
typename SrcDescs,
|
||||
typename DstDescs,
|
||||
typename ElementwiseOperation,
|
||||
typename DstInMemOps, // Sequence<InMemoryDataOperationEnum ...>
|
||||
typename SliceLengths,
|
||||
typename SrcDimAccessOrder,
|
||||
typename DstDimAccessOrder,
|
||||
index_t SrcVectorDim,
|
||||
index_t DstVectorDim,
|
||||
typename SrcScalarPerVectors,
|
||||
index_t DstScalarPerVector,
|
||||
typename SrcResetCoordinateAfterRunFlags, // Sequence<bool ...>
|
||||
typename DstResetCoordinateAfterRunFlags, // Sequence<bool ...>
|
||||
index_t ScatterDim = 1,
|
||||
bool OutputScatter = true,
|
||||
index_t ScatterWeightIdx = 3,
|
||||
index_t NumThreadScratch = 1>
|
||||
struct ThreadwiseTensorSliceTransfer_v7r3_scatter
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
|
||||
static constexpr auto SrcScalarPerVector = SrcScalarPerVectors{}[I0];
|
||||
|
||||
static constexpr index_t nDim = SliceLengths::Size();
|
||||
|
||||
static constexpr index_t nSrc = SrcDescs::Size();
|
||||
static constexpr index_t nDst = DstDescs::Size();
|
||||
|
||||
using Index = MultiIndex<nDim>;
|
||||
static constexpr index_t scatter_num = SliceLengths{}.At(Number<ScatterDim>{});
|
||||
|
||||
// return a tuple of coordiantes for a tuple of tensor
|
||||
template <typename Descs,
|
||||
typename Indices,
|
||||
enable_if_t<Descs::Size() == Indices::Size(), bool> = false>
|
||||
static constexpr auto MakeCoordinates(const Descs& descs, const Indices& indices)
|
||||
{
|
||||
return generate_tuple([&](auto i) { return make_tensor_coordinate(descs[i], indices[i]); },
|
||||
Number<Descs::Size()>{});
|
||||
}
|
||||
|
||||
using SrcCoords = decltype(MakeCoordinates(SrcDescs{}, StaticallyIndexedArray<Index, nSrc>{}));
|
||||
using DstCoords = decltype(MakeCoordinates(DstDescs{}, StaticallyIndexedArray<Index, nDst>{}));
|
||||
|
||||
// scalar per access on each dim
|
||||
// FIXME: don't use lambda_scalar_per_access
|
||||
static constexpr auto src_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
|
||||
|
||||
static constexpr auto dst_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
|
||||
|
||||
using SrcSpaceFillingCurve = SpaceFillingCurve<SliceLengths,
|
||||
SrcDimAccessOrder,
|
||||
remove_cv_t<decltype(src_scalar_per_access)>,
|
||||
false>;
|
||||
|
||||
using DstSpaceFillingCurve = SpaceFillingCurve<SliceLengths,
|
||||
DstDimAccessOrder,
|
||||
remove_cv_t<decltype(dst_scalar_per_access)>,
|
||||
false>;
|
||||
|
||||
__device__ constexpr ThreadwiseTensorSliceTransfer_v7r3_scatter(
|
||||
const SrcDescs& src_descs,
|
||||
const StaticallyIndexedArray<Index, nSrc>& src_slice_origins,
|
||||
const DstDescs& dst_descs,
|
||||
const StaticallyIndexedArray<Index, nDst>& dst_slice_origins,
|
||||
const ElementwiseOperation& element_op)
|
||||
: src_coords_(MakeCoordinates(src_descs, src_slice_origins)),
|
||||
dst_coords_(MakeCoordinates(dst_descs, dst_slice_origins)),
|
||||
element_op_(element_op)
|
||||
{
|
||||
static_assert(SliceLengths::At(Number<SrcVectorDim>{}) % SrcScalarPerVector == 0,
|
||||
"wrong! cannot evenly divide");
|
||||
|
||||
static_assert(SliceLengths::At(Number<DstVectorDim>{}) % DstScalarPerVector == 0,
|
||||
"wrong! cannot evenly divide");
|
||||
}
|
||||
|
||||
template <typename Indices, enable_if_t<SrcDescs::Size() == Indices::Size(), bool> = false>
|
||||
__device__ void SetSrcSliceOrigins(const SrcDescs& src_descs,
|
||||
const Indices& src_slice_origin_idxs)
|
||||
{
|
||||
static_for<0, nSrc, 1>{}([&](auto i) {
|
||||
src_coords_(i) = make_tensor_coordinate(src_descs[i], src_slice_origin_idxs[i]);
|
||||
});
|
||||
}
|
||||
|
||||
template <typename Indices, enable_if_t<DstDescs::Size() == Indices::Size(), bool> = false>
|
||||
__device__ void SetDstSliceOrigins(const DstDescs& dst_descs,
|
||||
const Indices& dst_slice_origin_idxs)
|
||||
{
|
||||
static_for<0, nDst, 1>{}([&](auto i) {
|
||||
dst_coords_(i) = make_tensor_coordinate(dst_descs[i], dst_slice_origin_idxs[i]);
|
||||
// printf("tid %d origin %d %d %d %d off %d\n", threadIdx.x,
|
||||
// dst_slice_origin_idxs[i][I0], dst_slice_origin_idxs[i][I1],
|
||||
// dst_slice_origin_idxs[i][I2], dst_slice_origin_idxs[i][I3],
|
||||
// dst_coords_(i).GetOffset());
|
||||
});
|
||||
}
|
||||
|
||||
template <typename DataTypes, index_t ScalarPerVector>
|
||||
__device__ static auto generate_vectors()
|
||||
{
|
||||
auto data_types = DataTypes{};
|
||||
|
||||
constexpr index_t num = data_types.Size();
|
||||
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
using DataType = remove_cvref_t<decltype(data_types[i])>;
|
||||
|
||||
return vector_type_maker_t<DataType, ScalarPerVector>{};
|
||||
},
|
||||
Number<num>{});
|
||||
}
|
||||
|
||||
// SrcDescs: Tuple<const SrcDesc0&, const SrcDesc1&, ...>
|
||||
// SrcBuffers: Tuple<const SrcBuffer0&, const SrcBuffer1&, ...>
|
||||
template <typename SrcBuffers,
|
||||
index_t ThreadScratchId = 0,
|
||||
enable_if_t<SrcDescs::Size() == SrcBuffers::Size(), bool> = false>
|
||||
__device__ void RunRead(const SrcDescs& src_descs,
|
||||
const SrcBuffers& src_bufs,
|
||||
StaticallyIndexedArray<float, scatter_num>& scatter_weights,
|
||||
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
|
||||
{
|
||||
// loop over space-filling curve
|
||||
static_for<0, src_num_access, 1>{}([&](auto iAccess) {
|
||||
auto src_vectors = generate_vectors<SrcDatas, SrcScalarPerVector>();
|
||||
auto elm_vectors = generate_vectors<DstDatas, SrcScalarPerVector>();
|
||||
|
||||
bool oob_val = true;
|
||||
|
||||
// copy data from src_bufs into src_vectors
|
||||
static_for<0, nSrc, 1>{}([&](auto i) {
|
||||
using src_vector_t = typename remove_cvref_t<decltype(src_vectors[i])>::type;
|
||||
|
||||
const bool is_src_valid =
|
||||
coordinate_has_valid_offset_assuming_visible_index_is_valid(src_descs[i],
|
||||
src_coords_[i]);
|
||||
|
||||
oob_val = oob_val & is_src_valid;
|
||||
if(i.value == ScatterWeightIdx)
|
||||
{
|
||||
static_assert(SrcScalarPerVectors{}[Number<ScatterWeightIdx>{}] == 1,
|
||||
"scatter weight dim, should only one vec");
|
||||
constexpr auto iScatter =
|
||||
SrcSpaceFillingCurve::GetIndex(iAccess)(Number<ScatterDim>{});
|
||||
// if(threadIdx.x % 8 ==0 )
|
||||
// printf("bid %d tid %d srcid %d sv %f\n", blockIdx.y, threadIdx.x, i.value,
|
||||
// scatter_weights(Number<iScatter>{}));
|
||||
static_for<0, SrcScalarPerVector, 1>{}([&](auto j) {
|
||||
src_vectors(i).template AsType<float>()(j) =
|
||||
scatter_weights(Number<iScatter>{});
|
||||
});
|
||||
}
|
||||
else if constexpr(SrcScalarPerVectors{}[i] == 1)
|
||||
{
|
||||
auto data_types = SrcDatas{};
|
||||
using DataType = remove_cvref_t<decltype(data_types[i])>;
|
||||
const auto tmp =
|
||||
src_bufs[i].template Get<DataType>(src_coords_[i].GetOffset(), true);
|
||||
// if(threadIdx.x % 8 ==0 )
|
||||
// printf("bid %d tid %d srcid %d off %d v %f\n", blockIdx.y, threadIdx.x,
|
||||
// i.value, src_coords_[i].GetOffset(), tmp);
|
||||
static_for<0, SrcScalarPerVector, 1>{}(
|
||||
[&](auto j) { src_vectors(i).template AsType<DataType>()(j) = tmp; });
|
||||
}
|
||||
else
|
||||
{
|
||||
// if(threadIdx.x % 8 ==0 )
|
||||
// printf("bid %d tid %d srcid %d vn\n", blockIdx.y, threadIdx.x, i.value);
|
||||
src_vectors(i).template AsType<src_vector_t>()(I0) =
|
||||
src_bufs[i].template Get<src_vector_t>(src_coords_[i].GetOffset(), true);
|
||||
}
|
||||
});
|
||||
|
||||
constexpr auto get_elem_op_vec_len = []() {
|
||||
if constexpr(is_detected<is_pack8_invocable_t, decltype(element_op_)>::value)
|
||||
{
|
||||
if constexpr(decltype(element_op_)::is_pack8_invocable)
|
||||
return math::min(8, SrcScalarPerVector);
|
||||
}
|
||||
if constexpr(is_detected<is_pack4_invocable_t, decltype(element_op_)>::value)
|
||||
{
|
||||
if constexpr(decltype(element_op_)::is_pack4_invocable)
|
||||
return math::min(4, SrcScalarPerVector);
|
||||
}
|
||||
if constexpr(is_detected<is_pack2_invocable_t, decltype(element_op_)>::value)
|
||||
{
|
||||
if constexpr(decltype(element_op_)::is_pack2_invocable)
|
||||
return math::min(2, SrcScalarPerVector);
|
||||
}
|
||||
return 1;
|
||||
};
|
||||
|
||||
constexpr index_t elem_op_vec_len = get_elem_op_vec_len();
|
||||
|
||||
// apply pointwise function
|
||||
static_for<0, SrcScalarPerVector / elem_op_vec_len, 1>{}([&](auto i) {
|
||||
// get reference to src data
|
||||
const auto src_data_refs = generate_tie(
|
||||
// return type should be lvalue
|
||||
[&](auto iSrc) -> const auto& {
|
||||
using SrcData = remove_cvref_t<tuple_element_t<iSrc.value, SrcDatas>>;
|
||||
|
||||
using elem_op_vec_t = typename vector_type<SrcData, elem_op_vec_len>::type;
|
||||
|
||||
return src_vectors[iSrc].template AsType<elem_op_vec_t>()[i];
|
||||
},
|
||||
Number<nSrc>{});
|
||||
|
||||
// get reference to dst data
|
||||
auto dst_data_refs = generate_tie(
|
||||
// return type should be lvalue
|
||||
[&](auto iDst) -> auto& {
|
||||
using DstData = remove_cvref_t<tuple_element_t<iDst.value, DstDatas>>;
|
||||
|
||||
using elem_op_vec_t = typename vector_type<DstData, elem_op_vec_len>::type;
|
||||
|
||||
return elm_vectors(iDst).template AsType<elem_op_vec_t>()(i);
|
||||
},
|
||||
Number<nDst>{});
|
||||
|
||||
// apply pointwise function
|
||||
// pointwise function signature:
|
||||
// element_op_(dst_data_refs[I0],
|
||||
// dst_data_refs[I1],
|
||||
// ...,
|
||||
// src_data_refs[I0],
|
||||
// src_data_refs[I1],
|
||||
// ...)
|
||||
unpack2(element_op_, dst_data_refs, src_data_refs);
|
||||
});
|
||||
|
||||
elm_vectors_tuple_(thread_scratch_id)(iAccess) = elm_vectors;
|
||||
oob_vectors_tuple_(thread_scratch_id)(iAccess) = oob_val;
|
||||
|
||||
// move coordinate
|
||||
if constexpr(iAccess.value != src_num_access - 1)
|
||||
{
|
||||
constexpr auto forward_step = SrcSpaceFillingCurve::GetForwardStep(iAccess);
|
||||
|
||||
static_for<0, nSrc, 1>{}([&](auto i) {
|
||||
move_tensor_coordinate(src_descs[i],
|
||||
src_coords_(i),
|
||||
make_tensor_coordinate_step(src_descs[i], forward_step));
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
// move coordinate back to slice origin (or not)
|
||||
static_for<0, nSrc, 1>{}([&](auto i) {
|
||||
if constexpr(SrcResetCoordinateAfterRunFlags::At(i))
|
||||
{
|
||||
const auto src_reset_step =
|
||||
make_tensor_coordinate_step(src_descs[i], GetSrcCoordinateResetStep());
|
||||
|
||||
move_tensor_coordinate(src_descs[i], src_coords_(i), src_reset_step);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
#if 1
|
||||
template <index_t ThreadScratchId = 0>
|
||||
__device__ void OOBCheck(Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
|
||||
{
|
||||
// loop over space-filling curve
|
||||
static_for<0, src_num_access, 1>{}([&](auto iAccess) {
|
||||
auto elm_vectors = elm_vectors_tuple_[thread_scratch_id][iAccess];
|
||||
auto oob_val = oob_vectors_tuple_[thread_scratch_id][iAccess];
|
||||
|
||||
static_for<0, nDst, 1>{}([&](auto i) {
|
||||
using elm_vector_t = typename remove_cvref_t<decltype(elm_vectors[i])>::type;
|
||||
elm_vectors(i).template AsType<elm_vector_t>()(I0) =
|
||||
oob_val ? elm_vectors(i).template AsType<elm_vector_t>()[I0] : elm_vector_t{0};
|
||||
});
|
||||
|
||||
elm_vectors_tuple_(thread_scratch_id)(iAccess) = elm_vectors;
|
||||
});
|
||||
}
|
||||
#endif
|
||||
|
||||
template <index_t ThreadScratchId = 0>
|
||||
__device__ void
|
||||
TransposeFromElmToDst(Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
|
||||
{
|
||||
using DstData = remove_cvref_t<decltype(DstDatas{}[I0])>;
|
||||
|
||||
using ElmThreadScratch =
|
||||
StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
|
||||
DstData,
|
||||
SrcScalarPerVector,
|
||||
decltype(GetSrcThreadScratchDescriptor()),
|
||||
true>;
|
||||
using DstThreadScratch =
|
||||
StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
|
||||
DstData,
|
||||
DstScalarPerVector,
|
||||
decltype(GetDstThreadScratchDescriptor()),
|
||||
true>;
|
||||
|
||||
ElmThreadScratch elm_thread_scratch_;
|
||||
DstThreadScratch dst_thread_scratch_;
|
||||
|
||||
elm_thread_scratch_.data_ =
|
||||
bit_cast<decltype(elm_thread_scratch_.data_)>(elm_vectors_tuple_[thread_scratch_id]);
|
||||
|
||||
if constexpr(SrcVectorDim != DstVectorDim &&
|
||||
((is_same<half_t, remove_cvref_t<DstData>>::value &&
|
||||
SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0) ||
|
||||
(is_same<f8_t, remove_cvref_t<DstData>>::value &&
|
||||
SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0) ||
|
||||
(is_same<int8_t, remove_cvref_t<DstData>>::value &&
|
||||
SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0)))
|
||||
{
|
||||
// each transpose does
|
||||
// DstScalarPerVector # of src vectors in src_thread_scratch_
|
||||
// SrcScalarPerVector # of dst vectors in dst_thread_scratch_
|
||||
constexpr index_t num_src_vector = Number<DstScalarPerVector>{};
|
||||
constexpr index_t num_dst_vector = Number<SrcScalarPerVector>{};
|
||||
|
||||
// Assume SrcVectorDim is not the same as DstVectorDim, so we do transpose
|
||||
// TODO: make this logic generic for all scenario
|
||||
|
||||
constexpr auto src_scalar_step_in_vector = generate_sequence(
|
||||
detail::lambda_scalar_step_in_vector<SrcVectorDim>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto dst_scalar_step_in_vector = generate_sequence(
|
||||
detail::lambda_scalar_step_in_vector<DstVectorDim>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access_for_src_and_dst<SrcVectorDim,
|
||||
SrcScalarPerVector,
|
||||
DstVectorDim,
|
||||
DstScalarPerVector>{},
|
||||
Number<nDim>{});
|
||||
|
||||
constexpr auto access_lengths = SliceLengths{} / scalar_per_access;
|
||||
|
||||
static_ford<decltype(access_lengths)>{}([&](auto access_idx) {
|
||||
constexpr auto data_idx = access_idx * scalar_per_access;
|
||||
|
||||
constexpr auto data_idx_seq = generate_sequence_v2(
|
||||
[&](auto i) { return Number<data_idx[i]>{}; }, Number<nDim>{});
|
||||
|
||||
using src_vector_t = vector_type_maker_t<DstData, SrcScalarPerVector>;
|
||||
using dst_vector_t = vector_type_maker_t<DstData, DstScalarPerVector>;
|
||||
|
||||
// get DstScalarPerVector # of read-only references to src vectors from
|
||||
// src_thread_scratch_
|
||||
const auto src_vector_refs = generate_tie(
|
||||
[&](auto i) -> const src_vector_t& {
|
||||
// i increment corresponds to movement in DstVectorDim
|
||||
return elm_thread_scratch_.GetVectorTypeReference(
|
||||
data_idx_seq + i * dst_scalar_step_in_vector);
|
||||
},
|
||||
Number<num_src_vector>{});
|
||||
|
||||
// get SrcScalarPerVector # of references to dst vectors from
|
||||
// dst_thread_scratch_
|
||||
auto dst_vector_refs = generate_tie(
|
||||
[&](auto i) -> dst_vector_t& {
|
||||
// i increment corresponds to movement in SrcVectorDim
|
||||
return dst_thread_scratch_.GetVectorTypeReference(
|
||||
data_idx_seq + i * src_scalar_step_in_vector);
|
||||
},
|
||||
Number<num_dst_vector>{});
|
||||
|
||||
// do data transpose
|
||||
transpose_vectors<DstData, DstScalarPerVector, SrcScalarPerVector>{}(
|
||||
src_vector_refs, dst_vector_refs);
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
static_ford<SliceLengths>{}(
|
||||
[&](auto idx) { dst_thread_scratch_(idx) = elm_thread_scratch_[idx]; });
|
||||
}
|
||||
|
||||
dst_vectors_tuple_(thread_scratch_id) = bit_cast<DstVectorTuple>(dst_thread_scratch_.data_);
|
||||
}
|
||||
|
||||
// DstDescs: Tuple<const DstDesc0&, const DstDesc1&, ...>
|
||||
// DstBuffers: Tuple<const DstBuffer0&, const DstBuffer1&, ...>
|
||||
template <typename DstBuffers,
|
||||
index_t ThreadScratchId = 0,
|
||||
enable_if_t<DstDescs::Size() == 1 && DstBuffers::Size() == 1, bool> = false>
|
||||
__device__ void RunWrite(const DstDescs& dst_descs,
|
||||
DstBuffers dst_bufs,
|
||||
StaticallyIndexedArray<index_t, scatter_num>& scatter_offsets,
|
||||
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
|
||||
{
|
||||
OOBCheck(thread_scratch_id);
|
||||
TransposeFromElmToDst(thread_scratch_id);
|
||||
|
||||
// loop over space-filling curve
|
||||
static_for<0, dst_num_access, 1>{}([&](auto iAccess) {
|
||||
auto dst_vectors = dst_vectors_tuple_[thread_scratch_id][iAccess];
|
||||
auto scatter_offset = 0;
|
||||
if constexpr(OutputScatter)
|
||||
{
|
||||
constexpr auto iScatter =
|
||||
DstSpaceFillingCurve::GetIndex(iAccess)(Number<ScatterDim>{});
|
||||
scatter_offset = scatter_offsets(Number<iScatter>{});
|
||||
}
|
||||
// copy data from buf_vectors into dst_bufs
|
||||
static_for<0, nDst, 1>{}([&](auto i) {
|
||||
using dst_vector_t = typename remove_cvref_t<decltype(dst_vectors[i])>::type;
|
||||
auto dst_offset = scatter_offset + dst_coords_[i].GetOffset();
|
||||
const bool is_dst_valid = dst_offset < dst_descs[i].GetElementSpaceSize();
|
||||
// coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_descs[i],
|
||||
// dst_coords_[i]);
|
||||
|
||||
constexpr InMemoryDataOperationEnum DstInMemOp =
|
||||
static_cast<InMemoryDataOperationEnum>(DstInMemOps::At(i.value));
|
||||
|
||||
// if(threadIdx.x==0)
|
||||
// printf("use tid %d off %d %d\n", threadIdx.x, dst_coords_[i].GetOffset(),
|
||||
// scatter_offset );
|
||||
dst_bufs(i).template Update<DstInMemOp, dst_vector_t>(
|
||||
dst_offset, is_dst_valid, dst_vectors[i].template AsType<dst_vector_t>()[I0]);
|
||||
// if(threadIdx.x%8 ==0 && blockIdx.x==0) {
|
||||
// static_for<0, 1, 1>{}([&](auto idx) {
|
||||
// using DstData = remove_cvref_t<tuple_element_t<0, DstDatas>>;
|
||||
// using print_vec_t = typename vector_type<DstData, 1>::type;
|
||||
// printf("tid %d off %d valid %d %f\n",threadIdx.x, dst_offset,
|
||||
// is_dst_valid, type_convert<float>(dst_vectors[i].template
|
||||
// AsType<print_vec_t>()[idx]));
|
||||
// });
|
||||
// }
|
||||
});
|
||||
|
||||
// move coordinate
|
||||
if constexpr(iAccess.value != dst_num_access - 1)
|
||||
{
|
||||
constexpr auto forward_step = DstSpaceFillingCurve::GetForwardStep(iAccess);
|
||||
|
||||
auto forward_step_scatter = [&]() constexpr
|
||||
{
|
||||
Index step_;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
step_(i) = (i.value == ScatterDim && OutputScatter) ? 0 : forward_step[i];
|
||||
|
||||
// if(threadIdx.x==0)
|
||||
// printf("i %d %d ordered_gather_dim %d\n", i.value, step_(i),
|
||||
// ordered_gather_dim);
|
||||
});
|
||||
|
||||
return step_;
|
||||
}
|
||||
();
|
||||
static_for<0, nDst, 1>{}([&](auto i) {
|
||||
move_tensor_coordinate(
|
||||
dst_descs[i],
|
||||
dst_coords_(i),
|
||||
make_tensor_coordinate_step(dst_descs[i], forward_step_scatter));
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
static_for<0, nDst, 1>{}([&](auto i) {
|
||||
if constexpr(DstResetCoordinateAfterRunFlags::At(i))
|
||||
{
|
||||
const auto dst_reset_step =
|
||||
make_tensor_coordinate_step(dst_descs[i], GetDstCoordinateResetStep());
|
||||
|
||||
move_tensor_coordinate(dst_descs[i], dst_coords_(i), dst_reset_step);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// SrcDescs: Tuple<const SrcDesc0&, const SrcDesc1&, ...>
|
||||
// SrcBuffers: Tuple<const SrcBuffer0&, const SrcBuffer1&, ...>
|
||||
// DstDescs: Tuple<const DstDesc0&, const DstDesc1&, ...>
|
||||
// DstBuffers: Tuple<const DstBuffer0&, const DstBuffer1&, ...>
|
||||
template <typename SrcBuffers,
|
||||
typename DstBuffers,
|
||||
enable_if_t<SrcDescs::Size() == SrcBuffers::Size() &&
|
||||
DstDescs::Size() == DstBuffers::Size(),
|
||||
bool> = false>
|
||||
__device__ void Run(const SrcDescs& src_descs,
|
||||
const SrcBuffers& src_bufs,
|
||||
const DstDescs& dst_descs,
|
||||
DstBuffers dst_bufs,
|
||||
StaticallyIndexedArray<index_t, scatter_num>& scatter_offsets,
|
||||
StaticallyIndexedArray<float, scatter_num>& scatter_weights)
|
||||
{
|
||||
RunRead(src_descs, src_bufs, scatter_weights);
|
||||
RunWrite(dst_descs, dst_bufs, scatter_offsets);
|
||||
}
|
||||
|
||||
__device__ static constexpr auto GetSrcCoordinateResetStep()
|
||||
{
|
||||
if constexpr(src_num_access == 0)
|
||||
{
|
||||
return typename SrcSpaceFillingCurve::Index{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return SrcSpaceFillingCurve::GetStepBetween(Number<src_num_access - 1>{}, Number<0>{});
|
||||
}
|
||||
}
|
||||
|
||||
__device__ static constexpr auto GetDstCoordinateResetStep()
|
||||
{
|
||||
if constexpr(dst_num_access == 0)
|
||||
{
|
||||
return typename DstSpaceFillingCurve::Index{};
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr auto reset_step =
|
||||
DstSpaceFillingCurve::GetStepBetween(Number<dst_num_access - 1>{}, Number<0>{});
|
||||
auto reset_step_scatter = [&]() constexpr
|
||||
{
|
||||
Index step_;
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
step_(i) =
|
||||
(i.value == ScatterDim && OutputScatter) ? 0 : reset_step[Number<i>{}];
|
||||
|
||||
// if(threadIdx.x==0)
|
||||
// printf("i %d %d ordered_gather_dim %d\n", i.value, step_(i),
|
||||
// ordered_gather_dim);
|
||||
});
|
||||
|
||||
return step_;
|
||||
}
|
||||
();
|
||||
return reset_step_scatter;
|
||||
}
|
||||
}
|
||||
|
||||
__device__ static constexpr auto GetSrcThreadScratchDescriptor()
|
||||
{
|
||||
// constexpr auto src_scalar_per_access = generate_sequence(
|
||||
// detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{},
|
||||
// Number<nDim>{});
|
||||
|
||||
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
|
||||
|
||||
constexpr auto src_access_lengths_and_vector_length = container_push_back(
|
||||
sequence_to_tuple_of_number(src_access_lengths), Number<SrcScalarPerVector>{});
|
||||
|
||||
// 1st stage of transforms
|
||||
constexpr auto desc0 =
|
||||
make_naive_tensor_descriptor_packed(src_access_lengths_and_vector_length);
|
||||
|
||||
// 2nd stage of transforms
|
||||
constexpr auto transforms = generate_tuple(
|
||||
[&](auto i) {
|
||||
if constexpr(i == SrcVectorDim)
|
||||
{
|
||||
return make_merge_transform_v3_division_mod(
|
||||
make_tuple(src_access_lengths_and_vector_length[i],
|
||||
src_access_lengths_and_vector_length[Number<nDim>{}]));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_pass_through_transform(src_access_lengths_and_vector_length[i]);
|
||||
}
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
constexpr auto low_dim_idss = generate_tuple(
|
||||
[&](auto i) {
|
||||
if constexpr(i == SrcVectorDim)
|
||||
{
|
||||
return Sequence<i.value, nDim>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return Sequence<i.value>{};
|
||||
}
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
constexpr auto up_dim_idss =
|
||||
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
|
||||
|
||||
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
|
||||
}
|
||||
|
||||
__device__ static constexpr auto GetDstThreadScratchDescriptor()
|
||||
{
|
||||
// 1st stage of transforms
|
||||
// constexpr auto dst_scalar_per_access = generate_sequence(
|
||||
// detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{},
|
||||
// Number<nDim>{});
|
||||
|
||||
constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
|
||||
|
||||
constexpr auto dst_access_lengths_and_vector_length = container_push_back(
|
||||
sequence_to_tuple_of_number(dst_access_lengths), Number<DstScalarPerVector>{});
|
||||
|
||||
constexpr auto desc0 =
|
||||
make_naive_tensor_descriptor_packed(dst_access_lengths_and_vector_length);
|
||||
|
||||
// 2nd stage of transforms
|
||||
constexpr auto transforms = generate_tuple(
|
||||
[&](auto i) {
|
||||
if constexpr(i == DstVectorDim)
|
||||
{
|
||||
return make_merge_transform_v3_division_mod(
|
||||
make_tuple(dst_access_lengths_and_vector_length[i],
|
||||
dst_access_lengths_and_vector_length[Number<nDim>{}]));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_pass_through_transform(dst_access_lengths_and_vector_length[i]);
|
||||
}
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
constexpr auto low_dim_idss = generate_tuple(
|
||||
[&](auto i) {
|
||||
if constexpr(i == DstVectorDim)
|
||||
{
|
||||
return Sequence<i.value, nDim>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return Sequence<i.value>{};
|
||||
}
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
constexpr auto up_dim_idss =
|
||||
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
|
||||
|
||||
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
|
||||
}
|
||||
|
||||
// src_slice_origin_step_idx need to be known at compile-time, for performance reason
|
||||
template <index_t ISrc>
|
||||
__device__ void MoveSrcSliceWindow(const SrcDescs& src_descs,
|
||||
Number<ISrc> iSrc,
|
||||
const Index& src_slice_origin_step_idx)
|
||||
{
|
||||
// if src coord was not reset by RunRead(), then need to adjust the step here
|
||||
const auto adjusted_step_idx =
|
||||
SrcResetCoordinateAfterRunFlags::At(iSrc)
|
||||
? src_slice_origin_step_idx
|
||||
: src_slice_origin_step_idx + GetSrcCoordinateResetStep();
|
||||
|
||||
// is it OK to construct a new step every time?
|
||||
const auto adjusted_step = make_tensor_coordinate_step(src_descs[iSrc], adjusted_step_idx);
|
||||
|
||||
move_tensor_coordinate(src_descs[iSrc], src_coords_(iSrc), adjusted_step);
|
||||
}
|
||||
|
||||
// dst_slice_origin_step_idx need to be known at compile-time, for performance reason
|
||||
template <index_t IDst>
|
||||
__device__ void MoveDstSliceWindow(const DstDescs& dst_descs,
|
||||
Number<IDst> iDst,
|
||||
const Index& dst_slice_origin_step_idx)
|
||||
{
|
||||
// if dst coord was not reset by Run(), then need to adjust the step here
|
||||
const auto adjusted_step_idx =
|
||||
DstResetCoordinateAfterRunFlags::At(iDst)
|
||||
? dst_slice_origin_step_idx
|
||||
: dst_slice_origin_step_idx + GetDstCoordinateResetStep();
|
||||
|
||||
auto adjusted_step_idx_scatter = [&]() {
|
||||
Index step_;
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
step_(i) =
|
||||
(i.value == ScatterDim && OutputScatter) ? 0 : adjusted_step_idx[Number<i>{}];
|
||||
});
|
||||
|
||||
return step_;
|
||||
}();
|
||||
// is it OK to construct a new step every time?
|
||||
const auto adjusted_step =
|
||||
make_tensor_coordinate_step(dst_descs[iDst], adjusted_step_idx_scatter);
|
||||
|
||||
move_tensor_coordinate(dst_descs[iDst], dst_coords_(iDst), adjusted_step);
|
||||
}
|
||||
|
||||
private:
|
||||
using SrcVectorsType = decltype(generate_vectors<SrcDatas, SrcScalarPerVector>());
|
||||
using ElmVectorsType = decltype(generate_vectors<DstDatas, SrcScalarPerVector>());
|
||||
using DstVectorsType = decltype(generate_vectors<DstDatas, DstScalarPerVector>());
|
||||
|
||||
static constexpr auto src_num_access = SrcSpaceFillingCurve::GetNumOfAccess();
|
||||
static constexpr auto dst_num_access = DstSpaceFillingCurve::GetNumOfAccess();
|
||||
|
||||
using ElmVectorTuple = StaticallyIndexedArray<ElmVectorsType, src_num_access>;
|
||||
using DstVectorTuple = StaticallyIndexedArray<DstVectorsType, dst_num_access>;
|
||||
|
||||
StaticallyIndexedArray<ElmVectorTuple, NumThreadScratch> elm_vectors_tuple_;
|
||||
StaticallyIndexedArray<DstVectorTuple, NumThreadScratch> dst_vectors_tuple_;
|
||||
|
||||
using OOBVectorTuple = StaticallyIndexedArray<bool, src_num_access>;
|
||||
StaticallyIndexedArray<OOBVectorTuple, NumThreadScratch> oob_vectors_tuple_;
|
||||
|
||||
SrcCoords src_coords_;
|
||||
DstCoords dst_coords_;
|
||||
const ElementwiseOperation element_op_;
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
Reference in New Issue
Block a user