Moe gemm activation (#2026)

* fix useless code and remove usless oob

* clang format

* fix coredump in e2e test

* fix2

* fix clang format

* fix output oob

* impl int64 but result not correct

* int64 index ok now

* input output all ok

* fix uint32

* revert v1 test

* use uint32

* mork to support 13w tokens

* moe sorting fix moebuf

* fix merge

* update moe api fix aiter build

* fix buid

* fuse silu

* silu ok

* acale ok

* add silu

* change code

* gemm2 ok

* gufusion compatible ok, fix warnings

* gu fusion for m32 m64 ok

* support bf16 cshuffle

* i4 gemm2 ok

* i4 gemm2 ok and i4 gemm1 build

* 16x16 run ok

* change flops; change cshuffle dtype

* fuse gelu silu act in moe gemm1

* fp8 with act ready

* int4 act ready

* remove useless changes

* remove useless code change

* fix clang format

* add the arch limit of int4 moe gemm

* fuse moe activation

* fix fp8 16x16

* fix no quant case

* fix bugs

* fix fp8 gufusion bug

* remove useless comments

* refine activation code & complete moe example

* fix int8 bugs

* merge tkw1

---------

Co-authored-by: coderfeli <coderfeli@163.com>
Co-authored-by: feli <felix.li@amd.com>
Co-authored-by: illsilin <Illia.Silin@amd.com>
Co-authored-by: root <root@hjbog-srdc-51.amd.com>
Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
This commit is contained in:
lalala-sh
2025-04-23 10:35:34 +08:00
committed by GitHub
parent 94662b02d0
commit 39ba03f25d
19 changed files with 1975 additions and 496 deletions

View File

@@ -12,7 +12,7 @@
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1_gather.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp"
@@ -26,12 +26,17 @@ namespace ck {
// two lds chunks.
// 2. Occupied __shared__ won't release until whole shader end, a.k.a AB and C may not use same lds
// buffer when we declare __shared__ inside blkgemmpipe
enum Activation
{
gelu_and_mul = 0,
silu_and_mul = 1
};
template <typename GridwiseGemm,
bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
index_t MinimumOccupancy = 1,
bool IsInputGemm = false,
bool MulRoutedWeight = true,
TailNumber TailNum = TailNumber::Even>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
@@ -45,22 +50,19 @@ __global__ void
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
GridwiseGemm::template Run<HasMainKBlockLoop,
CGlobalMemoryDataOperation,
IsInputGemm,
MulRoutedWeight,
TailNum>(karg.p_sorted_token_ids,
karg.p_sorted_expert_ids,
karg.p_max_token_id,
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
karg.p_ds_grid,
karg.p_c_grid,
p_shared,
karg,
karg.a_element_op,
karg.b_element_op,
karg.c_element_op);
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_sorted_token_ids,
karg.p_sorted_expert_ids,
karg.p_max_token_id,
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
karg.p_ds_grid,
karg.p_c_grid,
p_shared,
karg,
karg.a_element_op,
karg.b_element_op,
karg.c_element_op);
#else
ignore = karg;
#endif // end of if (defined(__gfx9__))
@@ -70,8 +72,6 @@ template <typename GridwiseGemm,
bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
index_t MinimumOccupancy = 1,
bool IsInputGemm = false,
bool MulRoutedWeight = true,
TailNumber TailNum = TailNumber::Even>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
@@ -86,23 +86,20 @@ __global__ void
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
GridwiseGemm::template Run_2Lds<HasMainKBlockLoop,
CGlobalMemoryDataOperation,
IsInputGemm,
MulRoutedWeight,
TailNum>(karg.p_sorted_token_ids,
karg.p_sorted_expert_ids,
karg.p_max_token_id,
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
karg.p_ds_grid,
karg.p_c_grid,
p_shared,
p_shared1,
karg,
karg.a_element_op,
karg.b_element_op,
karg.c_element_op);
GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_sorted_token_ids,
karg.p_sorted_expert_ids,
karg.p_max_token_id,
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
karg.p_ds_grid,
karg.p_c_grid,
p_shared,
p_shared1,
karg,
karg.a_element_op,
karg.b_element_op,
karg.c_element_op);
#else
ignore = karg;
#endif // end of if (defined(__gfx9__))
@@ -154,7 +151,12 @@ template <typename ALayout,
typename CDEShuffleBlockTransferScalarPerVectors,
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
index_t ActivationOperation = 0,
bool NSwizzle = false,
bool IsInputGemm = true,
bool MulRoutedWeight = true,
bool PerTokenQuant = false,
typename IndexType = index_t,
typename ComputeTypeA = CDataType,
typename ComputeTypeB = ComputeTypeA,
typename LDSTypeA = ADataType,
@@ -227,6 +229,7 @@ struct GridwiseMoeGemm
const index_t mblock = math::integer_divide_ceil(M, MPerBlock);
const index_t gridx = NSwizzle ? nblock * mblock : nblock;
const index_t gridy = NSwizzle ? 1 : mblock;
return std::make_tuple(gridx, gridy, 1);
}
@@ -305,7 +308,7 @@ struct GridwiseMoeGemm
}
__host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
IndexType M, IndexType MPad, IndexType K, IndexType KPad, IndexType StrideA, IndexType AK0)
{
const auto a_grid_desc_mraw_kraw = [&]() {
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
@@ -497,8 +500,8 @@ struct GridwiseMoeGemm
}
template <typename ELayout>
__host__ __device__ static auto
MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
__host__ __device__ static auto MakeCGridDescriptor_M_N(
IndexType M, IndexType MPad, IndexType N, IndexType NPad, IndexType StrideC)
{
const auto c_grid_desc_mraw_nraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, ELayout>::value)
@@ -909,7 +912,8 @@ struct GridwiseMoeGemm
NPerXdl,
MXdlPerWave,
NXdlPerWave,
KPack>())>;
KPack,
IsInputGemm>())>;
__device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
@@ -1141,9 +1145,7 @@ struct GridwiseMoeGemm
template <bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
bool IsInputGemm = true,
bool MulRoutedWeight = true,
TailNumber TailNum = TailNumber::Odd>
TailNumber TailNum = TailNumber::Odd>
__device__ static void Run(const index_t* p_sorted_token_ids,
const index_t* p_sorted_expert_ids,
const index_t* p_max_token_id,
@@ -1203,6 +1205,7 @@ struct GridwiseMoeGemm
return {blockIdx.x, blockIdx.y};
}
}();
const index_t block_n_id = block_mn.first;
const index_t block_m_id = block_mn.second;
const index_t token0 =
@@ -1218,7 +1221,7 @@ struct GridwiseMoeGemm
if(token_pos >= max_token_id || token0 >= problem.NumTokens)
return;
StaticallyIndexedArray<index_t, AMRepeats> gather_offsets;
StaticallyIndexedArray<IndexType, AMRepeats> gather_offsets;
static_for<0, AMRepeats, 1>{}([&](auto m0) {
const index_t fused_token = p_sorted_token_ids[token_pos + m0];
index_t token_offset = fused_token & 0xffffff;
@@ -1226,9 +1229,10 @@ struct GridwiseMoeGemm
{
token_offset = token_offset * problem.TopK + (fused_token >> 24);
}
gather_offsets(m0) = token_offset * problem.K;
gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
});
const index_t expert_stride = __builtin_amdgcn_readfirstlane(problem.N * problem.K);
const index_t expert_stride =
__builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
// N0, K0, Blocksize*KPack
const index_t n_block_data_idx_on_grid =
@@ -1239,7 +1243,6 @@ struct GridwiseMoeGemm
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid + expert_id * expert_stride / BPackedSize,
b_grid_desc_bpreshuffled.GetElementSpaceSize());
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
@@ -1269,6 +1272,7 @@ struct GridwiseMoeGemm
1,
AThreadTransferSrcResetCoordinateAfterRun,
true,
IndexType,
1,
BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
@@ -1311,24 +1315,74 @@ struct GridwiseMoeGemm
static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
decltype(c_thread_buf) c_thread_buf_up;
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
float,
c_thread_buf.num_of_v_,
c_thread_buf.s_per_v,
true>
c_thread_buf_fp32;
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
KPerBlock);
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1,
a_blockwise_copy,
a_grid_buf,
a_block_buf,
a_block_slice_copy_step,
b_grid_desc_bpreshuffled,
b_blockwise_copy,
b_grid_buf,
b_block_buf,
b_block_slice_copy_step,
c_thread_buf,
num_k_block_main_loop);
if constexpr(IsInputGemm)
{
const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize;
const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid_up + expert_id * expert_stride / BPackedSize,
b_grid_desc_bpreshuffled.GetElementSpaceSize());
auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2<
BDataType,
BDataType,
decltype(b_grid_desc_bpreshuffled),
decltype(b_block_desc_bk0_n_bk1),
Sequence<Number<NXdlPerWave>{}, I1, Number<KRepeat>{}, Number<BK1Value>{}>,
Sequence<1, 2, 0, 3>,
3,
BBlockTransferSrcScalarPerVector,
BThreadTransferSrcResetCoordinateAfterRun,
true>(b_grid_desc_bpreshuffled,
make_multi_index(n_block_data_idx_on_grid,
get_warp_local_1d_id() % NWave,
0,
KPack * (get_thread_local_1d_id() % warpSize)));
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
a_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1,
a_blockwise_copy,
a_grid_buf,
a_block_buf,
a_block_slice_copy_step,
b_grid_desc_bpreshuffled,
b_blockwise_copy,
b_blockwise_copy_up,
b_grid_buf,
b_grid_buf_up,
b_block_buf,
b_block_slice_copy_step,
c_thread_buf,
c_thread_buf_up,
num_k_block_main_loop);
}
else
{
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
a_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1,
a_blockwise_copy,
a_grid_buf,
a_block_buf,
a_block_slice_copy_step,
b_grid_desc_bpreshuffled,
b_blockwise_copy,
b_grid_buf,
b_block_buf,
b_block_slice_copy_step,
c_thread_buf,
num_k_block_main_loop);
}
// shuffle C and write out
{
@@ -1356,6 +1410,185 @@ struct GridwiseMoeGemm
constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
// mul scales
const float* p_sorted_weights_0 = p_ds_grid[I0];
const float* p_scale_b = p_ds_grid[I1];
static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock);
static_assert(M4 == 4);
const index_t m1 = get_warp_local_1d_id() / NWave;
const index_t m3 = threadIdx.x % get_warp_size() / MPerXdl;
if(p_sorted_weights_0 != nullptr && p_scale_b != nullptr)
{
if constexpr(PerTokenQuant)
{
constexpr index_t scale_stride = (IsInputGemm ? 2 : 1);
p_scale_b += expert_id * problem.N * scale_stride + block_n_id * NPerBlock +
get_warp_local_1d_id() % NWave * NPerXdl + threadIdx.x % NPerXdl;
}
else
{
p_scale_b += expert_id;
}
vector_type<int32_t, 4> scale_token_ids;
vector_type<float, 4> topk_weights;
static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
const float scale_b = p_scale_b[n0 * NWave * NPerXdl * PerTokenQuant];
static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave
static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk
const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 +
m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4;
if constexpr(PerTokenQuant)
{
scale_token_ids =
*c_style_pointer_cast<const vector_type<int32_t, M4>*>(
p_sorted_token_ids + m_pos);
}
if constexpr(MulRoutedWeight)
{
topk_weights = *c_style_pointer_cast<const vector_type<float, M4>*>(
p_ds_grid[I2] + m_pos);
}
static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size
float scale_a = [&]() {
if constexpr(PerTokenQuant)
{
index_t fused_token = scale_token_ids.AsType<index_t>()[m4];
const index_t token_offset = fused_token & 0xffffff;
return token_offset < problem.NumTokens
? p_sorted_weights_0[token_offset]
: 0.0;
}
else
{
return p_sorted_weights_0[0];
}
}();
constexpr index_t c_offset =
blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
make_tuple(m0, n0, m2 * M4 + m4));
constexpr auto cidx = Number<c_offset>{};
if constexpr(IsInputGemm) // gu fusion
{
if constexpr(ActivationOperation == Activation::silu_and_mul)
{
const float scale_up =
p_scale_b[(n0 * NWave * NPerXdl + problem.N) *
PerTokenQuant];
float gate = scale_a * scale_b * c_thread_buf[cidx];
float up = scale_a * scale_up * c_thread_buf_up[cidx];
if constexpr(MulRoutedWeight)
{
gate = gate * topk_weights.AsType<float>()[m4];
up = up * topk_weights.AsType<float>()[m4];
}
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
{
gate *= 16;
up *= 16;
}
tensor_operation::element_wise::Silu{}(gate, gate);
c_thread_buf_fp32(cidx) = gate * up;
}
else if(ActivationOperation == Activation::gelu_and_mul)
{
const float scale_up =
p_scale_b[(n0 * NWave * NPerXdl + problem.N) *
PerTokenQuant];
float gate = scale_a * scale_b * c_thread_buf[cidx];
float up = scale_a * scale_up * c_thread_buf_up[cidx];
if constexpr(MulRoutedWeight)
{
gate = gate * topk_weights.AsType<float>()[m4];
up = up * topk_weights.AsType<float>()[m4];
}
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
{
gate *= 16;
up *= 16;
}
tensor_operation::element_wise::Gelu{}(gate, gate);
c_thread_buf_fp32(cidx) = gate * up;
}
}
else
{
c_thread_buf_fp32(cidx) =
scale_a * scale_b * c_thread_buf[cidx];
if constexpr(MulRoutedWeight)
{
c_thread_buf_fp32(cidx) = c_thread_buf_fp32(cidx) *
topk_weights.AsType<float>()[m4];
}
}
});
});
});
});
}
else
{
vector_type<float, 4> topk_weights; // for gemm2 only
static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave
static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk
const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 +
m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4;
if constexpr(MulRoutedWeight)
{
topk_weights = *c_style_pointer_cast<const vector_type<float, M4>*>(
p_ds_grid[I2] + m_pos);
}
static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size
constexpr index_t c_offset =
blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
make_tuple(m0, n0, m2 * M4 + m4));
constexpr auto cidx = Number<c_offset>{};
if constexpr(IsInputGemm) // gu fusion
{
if constexpr(ActivationOperation == Activation::silu_and_mul)
{
float gate = c_thread_buf[cidx];
float up = c_thread_buf_up[cidx];
if constexpr(MulRoutedWeight)
{
gate = gate * topk_weights.AsType<float>()[m4];
up = up * topk_weights.AsType<float>()[m4];
}
tensor_operation::element_wise::Silu{}(gate, gate);
c_thread_buf_fp32(cidx) = gate * up;
}
else if(ActivationOperation == Activation::gelu_and_mul)
{
float gate = c_thread_buf[cidx];
float up = c_thread_buf_up[cidx];
if constexpr(MulRoutedWeight)
{
gate = gate * topk_weights.AsType<float>()[m4];
up = up * topk_weights.AsType<float>()[m4];
}
tensor_operation::element_wise::Gelu{}(gate, gate);
c_thread_buf_fp32(cidx) = gate * up;
}
}
else
{
c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
if constexpr(MulRoutedWeight)
{
c_thread_buf_fp32(cidx) = topk_weights.AsType<float>()[m4] *
c_thread_buf_fp32[cidx];
}
}
});
});
});
});
}
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
@@ -1453,17 +1686,8 @@ struct GridwiseMoeGemm
const auto ds_grid_buf = generate_tuple(
[&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
const DDataType* ptr_ = p_ds_grid[i];
// hack logic here to support different kind of strides. todo fix it.
// ascale t, 1; bscale E, N, 1, move ptr to E
if(i.value == 1)
{
ptr_ +=
expert_id * (problem.StrideDs[1] ? problem.StrideDs[1] * problem.N : 1);
}
return make_dynamic_buffer<AddressSpaceEnum::Global>(
ptr_, ds_grid_desc_m_n[i].GetElementSpaceSize());
p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
},
Number<NumDTensor>{});
@@ -1526,7 +1750,8 @@ struct GridwiseMoeGemm
Sequence<true>,
uniform_sequence_gen_t<NumDTensor,
false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
IndexType,
1, // ScatterDim
true, // OutputScatter: false, only use scatter weights
scatter_weight_idx // ScatterWeightIdx: ascale
@@ -1538,7 +1763,6 @@ struct GridwiseMoeGemm
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
// space filling curve for threadwise C in VGPR
constexpr auto sfc_c_vgpr =
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
@@ -1568,35 +1792,21 @@ struct GridwiseMoeGemm
constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
constexpr auto ENThreads =
CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
const float* p_sorted_weights_0 = p_ds_grid[I0];
static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to write to LDS
StaticallyIndexedArray<index_t, EMRepeats> scatter_offsets;
StaticallyIndexedArray<float, EMRepeats> scatter_weights; //= for topk
StaticallyIndexedArray<IndexType, EMRepeats> scatter_offsets;
auto dstidx = sfc_cde_block.GetIndex(access_id);
const index_t c_token_pos =
block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
static_for<0, EMRepeats, 1>{}([&](auto m0) {
const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
index_t token_offset = fused_token & 0xffffff;
float weight = token_offset < problem.NumTokens
? p_sorted_weights_0[token_offset * problem.StrideDs[0]]
: 0.0;
IndexType token_offset = fused_token & 0xffffff;
if constexpr(IsInputGemm)
{
token_offset = token_offset * problem.TopK + (fused_token >> 24);
}
if constexpr(MulRoutedWeight)
{
const float* p_sorted_weights_2 = p_ds_grid[I2];
if constexpr(sizeof(ADataType) < 2)
weight = p_sorted_weights_2[c_token_pos + m0] * weight;
else
weight = p_sorted_weights_2[c_token_pos + m0];
}
scatter_offsets(m0) = token_offset * problem.N;
scatter_weights(m0) = weight;
scatter_offsets(m0) = static_cast<IndexType>(token_offset) * problem.N;
});
block_sync_lds();
@@ -1604,7 +1814,7 @@ struct GridwiseMoeGemm
// each thread write its data from VGPR to LDS
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
c_thread_buf,
c_thread_buf_fp32,
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
c_shuffle_block_buf);
@@ -1617,8 +1827,7 @@ struct GridwiseMoeGemm
c_ds_buf_refs,
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
tie(c_grid_buf),
scatter_offsets,
scatter_weights);
scatter_offsets);
if constexpr(access_id < num_access - 1)
{
@@ -1643,9 +1852,7 @@ struct GridwiseMoeGemm
template <bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
bool IsInputGemm = true,
bool MulRoutedWeight = true,
TailNumber TailNum = TailNumber::Odd>
TailNumber TailNum = TailNumber::Odd>
__device__ static void Run_2Lds(const index_t* p_sorted_token_ids,
const index_t* p_sorted_expert_ids,
const index_t* p_max_token_id,
@@ -1721,7 +1928,7 @@ struct GridwiseMoeGemm
if(token_pos >= max_token_id || expert_block_id * MPerBlock >= max_token_id ||
token0 >= problem.NumTokens)
return;
StaticallyIndexedArray<index_t, AMRepeats>
StaticallyIndexedArray<IndexType, AMRepeats>
gather_offsets; //= p_sorted_token_ids[token_pos];
static_for<0, AMRepeats, 1>{}([&](auto m0) {
const index_t fused_token = p_sorted_token_ids[token_pos + m0];
@@ -1730,7 +1937,7 @@ struct GridwiseMoeGemm
{
token_offset = token_offset * problem.TopK + (fused_token >> 24);
}
gather_offsets(m0) = token_offset * problem.K;
gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
});
const index_t expert_stride = __builtin_amdgcn_readfirstlane(problem.N * problem.K);
@@ -1773,6 +1980,7 @@ struct GridwiseMoeGemm
1,
AThreadTransferSrcResetCoordinateAfterRun,
true,
IndexType,
1,
2>(a_grid_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
@@ -1967,11 +2175,12 @@ struct GridwiseMoeGemm
const DDataType* ptr_ = p_ds_grid[i];
// hack logic here to support different kind of strides. todo fix it.
// ascale t, 1; bscale E, N, 1, move ptr to E
if(i.value == 1)
{
ptr_ +=
expert_id * (problem.StrideDs[1] ? problem.StrideDs[1] * problem.N : 1);
}
// if(i.value == 1)
// {
// ptr_ +=
// expert_id * (problem.StrideDs[1] ? problem.StrideDs[1] * problem.N :
// 1);
// }
return make_dynamic_buffer<AddressSpaceEnum::Global>(
ptr_, ds_grid_desc_m_n[i].GetElementSpaceSize());
},
@@ -2036,7 +2245,8 @@ struct GridwiseMoeGemm
Sequence<true>,
uniform_sequence_gen_t<NumDTensor,
false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
IndexType,
1, // ScatterDim
true, // OutputScatter: false, only use scatter weights
scatter_weight_idx // ScatterWeightIdx: ascale
@@ -2078,12 +2288,9 @@ struct GridwiseMoeGemm
constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
constexpr auto ENThreads =
CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
const float* p_sorted_weights_0 = p_ds_grid[I0];
static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to write to LDS
StaticallyIndexedArray<index_t, EMRepeats>
scatter_offsets; //= p_sorted_token_ids[c_token_pos];
StaticallyIndexedArray<float, EMRepeats> scatter_weights; //= for topk
StaticallyIndexedArray<IndexType, EMRepeats> scatter_offsets;
auto dstidx = sfc_cde_block.GetIndex(access_id);
const index_t c_token_pos =
@@ -2091,23 +2298,11 @@ struct GridwiseMoeGemm
static_for<0, EMRepeats, 1>{}([&](auto m0) {
const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
index_t token_offset = fused_token & 0xffffff;
float weight = token_offset < problem.NumTokens
? p_sorted_weights_0[token_offset * problem.StrideDs[0]]
: 0.0;
if constexpr(IsInputGemm)
{
token_offset = token_offset * problem.TopK + (fused_token >> 24);
}
if constexpr(MulRoutedWeight)
{
const float* p_sorted_weights_2 = p_ds_grid[I2];
if constexpr(sizeof(ADataType) < 2)
weight = p_sorted_weights_2[c_token_pos + m0] * weight;
else
weight = p_sorted_weights_2[c_token_pos + m0];
}
scatter_offsets(m0) = token_offset * problem.N;
scatter_weights(m0) = weight;
scatter_offsets(m0) = static_cast<IndexType>(token_offset) * problem.N;
});
block_sync_lds();
@@ -2128,8 +2323,7 @@ struct GridwiseMoeGemm
c_ds_buf_refs,
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
tie(c_grid_buf),
scatter_offsets,
scatter_weights);
scatter_offsets);
if constexpr(access_id < num_access - 1)
{