ck moe gemm implement (#1936)

* port all moe changes from ck_moe_gemm branch

* refine codes in the pr

* fix tail odd

* fix clang format

* fix clang format2

* make hot loop scheduler compatible with 16x16 and 32x32

* clang format

* fix per token quant

* rename moe example

* clang format

---------

Co-authored-by: coderfeli <coderfeli@163.com>
This commit is contained in:
feli
2025-03-05 15:56:55 +08:00
committed by GitHub
parent c95bda93ba
commit 3786e16375
13 changed files with 6144 additions and 7 deletions

View File

@@ -141,6 +141,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
using Base::AMmaKStride;
using Base::BMmaKStride;
using Base::MWaves;
static constexpr index_t PrefetchStages = 2;
static constexpr index_t PrefillStages = 1;
@@ -184,12 +185,19 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
{
constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num;
constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num;
constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num * MWaves;
constexpr auto mfma_interleave = MPerXDL == 32 ? 1 : 2;
// B global
static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
if constexpr(MPerBlock >= 128 && NPerBlock >= 128)
{
__builtin_amdgcn_sched_group_barrier(0x008, 2 * mfma_interleave, 0);
}
else
{
__builtin_amdgcn_sched_group_barrier(0x008, mfma_interleave, 0);
}
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
});
@@ -203,10 +211,10 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
});
// A local
static_for<0, num_ds_read_inst_a / 2, 1>{}([&](auto i) {
static_for<0, num_ds_read_inst_a / 2 * mfma_interleave, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 2 / mfma_interleave, 0); // DS read
});
}
@@ -320,7 +328,6 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type;

View File

@@ -0,0 +1,199 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_gather.hpp"
namespace ck {
/**
* @brief Blockwise data transfer
*
* This version does following things to avoid scratch memory issue
* 1. Use StaticallyIndexedArray instead of C array for thread buffer
* 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
* 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
*
*/
template <typename ThreadGroup,
typename SrcElementwiseOperation,
typename DstElementwiseOperation,
InMemoryDataOperationEnum DstInMemOp,
typename BlockSliceLengths,
typename ThreadClusterLengths,
typename ThreadClusterArrangeOrder,
typename SrcData,
typename DstData,
typename SrcDesc,
typename DstDesc,
typename SrcDimAccessOrder,
typename DstDimAccessOrder,
index_t SrcVectorDim,
index_t DstVectorDim,
index_t SrcScalarPerVector,
index_t DstScalarPerVector,
index_t SrcScalarStrideInVector,
index_t DstScalarStrideInVector,
bool ThreadTransferSrcResetCoordinateAfterRun,
bool ThreadTransferDstResetCoordinateAfterRun,
index_t GatherDim = 1,
index_t NumThreadScratch = 1>
struct ThreadGroupTensorSliceTransfer_v4r1_gather
{
static constexpr auto I0 = Number<0>{};
static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();
static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{};
static constexpr index_t gather_num = thread_slice_lengths.At(Number<GatherDim>{});
using Index = MultiIndex<nDim>;
__device__ constexpr ThreadGroupTensorSliceTransfer_v4r1_gather(
const SrcDesc& src_desc,
const Index& src_block_slice_origin,
const SrcElementwiseOperation& src_element_op,
const DstDesc& dst_desc,
const Index& dst_block_slice_origin,
const DstElementwiseOperation& dst_element_op,
const StaticallyIndexedArray<index_t, gather_num>& gather_offsets)
: threadwise_transfer_(src_desc,
make_zero_multi_index<nDim>(),
src_element_op,
dst_desc,
make_zero_multi_index<nDim>(),
dst_element_op,
gather_offsets)
{
static_assert(nDim == remove_cvref_t<SrcDesc>::GetNumOfDimension() &&
nDim == remove_cvref_t<DstDesc>::GetNumOfDimension() &&
nDim == ThreadClusterLengths::Size() &&
nDim == ThreadClusterArrangeOrder::Size() &&
nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(),
"wrong! nDim not consistent");
static_assert(
is_same<BlockSliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
"wrong! threads should be mapped to cover entire slicing window");
static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
"wrong! ThreadGroup::GetNumOfThread() too small");
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
make_multi_index(ThreadGroup::GetThreadId()));
threadwise_transfer_.SetSrcSliceOrigin(
src_desc, src_block_slice_origin + thread_cluster_idx * thread_slice_lengths);
threadwise_transfer_.SetDstSliceOrigin(
dst_desc, dst_block_slice_origin + thread_cluster_idx * thread_slice_lengths);
}
}
__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_block_slice_origin)
{
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
make_multi_index(ThreadGroup::GetThreadId()));
const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
threadwise_transfer_.SetSrcSliceOrigin(src_desc,
src_block_slice_origin + thread_data_idx_begin);
}
}
template <typename SeqIdx, index_t ThreadScratchId = 0>
__device__ constexpr auto GetSrcThreadScratchIdx()
{
return threadwise_transfer_.template GetSrcThreadScratchIdx<SeqIdx, ThreadScratchId>();
}
template <typename SrcBuffer, index_t ThreadScratchId = 0>
__device__ void RunRead(const SrcDesc& src_desc,
const SrcBuffer& src_buf,
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.RunRead(src_desc, src_buf, thread_scratch_id);
}
}
template <typename DstBuffer, index_t ThreadScratchId = 0>
__device__ void RunWrite(const DstDesc& dst_desc,
DstBuffer& dst_buf,
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.RunWrite(dst_desc, dst_buf, thread_scratch_id);
}
}
template <typename SrcBuffer, typename DstBuffer, index_t ThreadScratchId>
__device__ void Run(const SrcDesc& src_desc,
const SrcBuffer& src_buf,
const DstDesc& dst_desc,
DstBuffer& dst_buf,
Number<ThreadScratchId> thread_scratch_id)
{
RunRead(src_desc, src_buf, thread_scratch_id);
RunWrite(dst_desc, dst_buf, thread_scratch_id);
}
__device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step)
{
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveSrcSliceWindow(src_desc, step);
}
}
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
{
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
}
}
private:
static constexpr auto thread_cluster_desc_ =
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
using ThreadwiseTransfer =
ThreadwiseTensorSliceTransfer_v3r1_gather<decltype(thread_slice_lengths),
SrcElementwiseOperation,
DstElementwiseOperation,
DstInMemOp,
SrcData,
DstData,
SrcDesc,
DstDesc,
SrcDimAccessOrder,
DstDimAccessOrder,
SrcVectorDim,
DstVectorDim,
SrcScalarPerVector,
DstScalarPerVector,
SrcScalarStrideInVector,
DstScalarStrideInVector,
ThreadTransferSrcResetCoordinateAfterRun,
ThreadTransferDstResetCoordinateAfterRun,
GatherDim,
NumThreadScratch>;
ThreadwiseTransfer threadwise_transfer_;
};
} // namespace ck

View File

@@ -0,0 +1,241 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp"
#include "ck/utility/is_detected.hpp"
namespace ck {
// Thread-group level multi-source, multi-destination tensor slice data movement
// Assume:
// 1. All sources and destinations are DynamicBuffer
// 2. Same VectorDim and ScalerPerVector for all sources and destinations
// 3. DstInMemOps are per destination tensor
// 4. ThreadTransferSrcResetCoordinateAfterRunFlags are per source tensor
// 5. ThreadTransferDstResetCoordinateAfterRunFlags are per destination tensor
//
// Does following things to avoid scratch memory issue
// 1. Pass tensor descritpors by reference (or tuple of references)
// 2. Does not keep reference to tensor descriptor
// 3. Does not construct new tensor coordinate when call Run()
template <typename ThreadGroup,
typename SrcDatas,
typename DstDatas,
typename SrcDescs,
typename DstDescs,
typename ElementwiseOperation,
typename DstInMemOps, // Sequence<InMemoryDataOperationEnum ...>
typename SliceLengths,
typename ThreadClusterLengths,
typename ThreadClusterArrangeOrder,
typename SrcDimAccessOrder,
typename DstDimAccessOrder,
index_t SrcVectorDim,
index_t DstVectorDim,
typename SrcScalarPerVectors,
index_t DstScalarPerVector,
typename ThreadTransferSrcResetCoordinateAfterRunFlags,
typename ThreadTransferDstResetCoordinateAfterRunFlags,
index_t ScatterDim = 1,
bool OutputScatter = true,
index_t ScatterWeightIdx = 3,
index_t NumThreadScratch = 1>
struct ThreadGroupTensorSliceTransfer_v7r3_scatter
{
static constexpr index_t nDim =
remove_cvref_t<tuple_element_t<0, SrcDescs>>::GetNumOfDimension();
static constexpr index_t mod_num =
ThreadClusterLengths{}.At(Number<3>{}); // Dirty HACK FELIX, TODO fix
static constexpr index_t nSrc = remove_cvref_t<SrcDescs>::Size();
static constexpr index_t nDst = remove_cvref_t<DstDescs>::Size();
using Index = MultiIndex<nDim>;
static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{};
static constexpr index_t scatter_num = thread_slice_lengths.At(Number<ScatterDim>{});
__device__ constexpr ThreadGroupTensorSliceTransfer_v7r3_scatter(
const SrcDescs& src_descs,
const StaticallyIndexedArray<Index, nSrc>& src_block_slice_origins,
const DstDescs& dst_descs,
const StaticallyIndexedArray<Index, nDst>& dst_block_slice_origins,
const ElementwiseOperation& element_op)
: threadwise_transfer_(src_descs,
StaticallyIndexedArray<Index, nSrc>{},
dst_descs,
StaticallyIndexedArray<Index, nDst>{},
element_op)
{
static_assert(nSrc == SrcDatas::Size() && nSrc == SrcDescs::Size() &&
nSrc == ThreadTransferSrcResetCoordinateAfterRunFlags::Size() &&
nDst == DstDatas::Size() && nDst == DstDescs::Size() &&
nDst == ThreadTransferDstResetCoordinateAfterRunFlags::Size(),
"wrong!");
static_for<0, nSrc, 1>{}([&](auto i) {
static_assert(
nDim == remove_cvref_t<tuple_element_t<i.value, SrcDescs>>::GetNumOfDimension(),
"wrong!");
});
static_for<0, nDst, 1>{}([&](auto i) {
static_assert(
nDim == remove_cvref_t<tuple_element_t<i.value, DstDescs>>::GetNumOfDimension(),
"wrong!");
});
static_assert(nDim == ThreadClusterLengths::Size() &&
nDim == ThreadClusterArrangeOrder::Size() &&
nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(),
"wrong! nDim not consistent");
static_assert(
is_same<SliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
"wrong! threads should be mapped to cover entire slicing window");
static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
"wrong! ThreadGroup::GetNumOfThread() too small");
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
const auto src_thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
make_multi_index(ThreadGroup::GetThreadId()));
const auto src_thread_slice_origins = generate_tuple(
[&](auto i) {
return src_block_slice_origins[i] +
src_thread_cluster_idx * thread_slice_lengths;
},
Number<nSrc>{});
const auto dst_thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
make_multi_index(OutputScatter ? ThreadGroup::GetThreadId() % mod_num
: ThreadGroup::GetThreadId()));
const auto dst_thread_slice_origins = generate_tuple(
[&](auto i) {
return dst_block_slice_origins[i] +
dst_thread_cluster_idx * thread_slice_lengths;
},
Number<nDst>{});
threadwise_transfer_.SetSrcSliceOrigins(src_descs, src_thread_slice_origins);
threadwise_transfer_.SetDstSliceOrigins(dst_descs, dst_thread_slice_origins);
}
}
template <typename SrcBuffers, index_t ThreadScratchId = 0>
__device__ void RunRead(const SrcDescs& src_descs,
const SrcBuffers& src_bufs,
StaticallyIndexedArray<float, scatter_num>& scatter_weights,
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.RunRead(src_descs, src_bufs, scatter_weights, thread_scratch_id);
}
}
template <typename T>
using is_tuple = decltype(std::declval<T&>().IsTuple());
template <typename DstBuffers, index_t ThreadScratchId = 0>
__device__ void RunWrite(const DstDescs& dst_descs,
DstBuffers dst_bufs,
StaticallyIndexedArray<index_t, scatter_num>& scatter_offsets,
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
if constexpr(is_detected<is_tuple, decltype(dst_bufs)>::value)
threadwise_transfer_.RunWrite(
dst_descs, dst_bufs, scatter_offsets, thread_scratch_id);
else
threadwise_transfer_.RunWrite(
dst_descs, tie(dst_bufs), scatter_offsets, thread_scratch_id);
}
}
template <typename SrcBuffers, typename DstBuffers>
__device__ void Run(const SrcDescs& src_descs,
const SrcBuffers& src_bufs,
const DstDescs& dst_descs,
DstBuffers dst_bufs,
StaticallyIndexedArray<index_t, scatter_num>& scatter_offsets,
StaticallyIndexedArray<float, scatter_num>& scatter_weights)
{
RunRead(src_descs, src_bufs, scatter_weights);
RunWrite(dst_descs, dst_bufs, scatter_offsets);
}
template <index_t ISrc>
__device__ void
MoveSrcSliceWindow(const SrcDescs& src_descs, Number<ISrc> iSrc, const Index& step)
{
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveSrcSliceWindow(src_descs, iSrc, step);
}
}
__device__ void MoveSrcSliceWindow(const SrcDescs& src_descs, const Index& step)
{
static_for<0, SrcDescs::Size(), 1>{}(
[&](auto i) { MoveSrcSliceWindow(src_descs, i, step); });
}
template <index_t IDst>
__device__ void
MoveDstSliceWindow(const DstDescs& dst_descs, Number<IDst> iDst, const Index& step)
{
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveDstSliceWindow(dst_descs, iDst, step);
}
}
__device__ void MoveDstSliceWindow(const DstDescs& dst_descs, const Index& step)
{
static_for<0, DstDescs::Size(), 1>{}(
[&](auto i) { MoveDstSliceWindow(dst_descs, i, step); });
}
private:
static constexpr auto thread_cluster_desc_ =
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
using ThreadwiseTransfer =
ThreadwiseTensorSliceTransfer_v7r3_scatter<SrcDatas,
DstDatas,
SrcDescs,
DstDescs,
ElementwiseOperation,
DstInMemOps,
decltype(thread_slice_lengths),
SrcDimAccessOrder,
DstDimAccessOrder,
SrcVectorDim,
DstVectorDim,
SrcScalarPerVectors,
DstScalarPerVector,
ThreadTransferSrcResetCoordinateAfterRunFlags,
ThreadTransferDstResetCoordinateAfterRunFlags,
ScatterDim,
OutputScatter,
ScatterWeightIdx,
NumThreadScratch>;
ThreadwiseTransfer threadwise_transfer_;
};
} // namespace ck