Add MoE & FP8 Blockscale WP Kernels for GFX950 (#2297)

* [fix] align v3 gufusion pipeline

* fix device kernel selection.

* Add .co direct asm support by CK_USE_ASM_MOE_STAGE2_BLOCKSCALE

* experimental optimization for scale load in blkscale gemm

* Add asm for no-loop v3_128x128x128

* fix bugs

* tune fp8 example

* Update v1_128x128x128 to 2x2 instead of 4x1

* wip

* add warmup to asm launch

* wip2

* 16x16 function merged to moe

* temp save, a performant version.

* wip3

* Update .co binary to 16x16

* 16x16x128 correct; 64x64x128 failed

* update

* use mem_op::set when topk=1

* add mx fp8 b_preshuffle support, function not yet tested.

* Spilt the fp4 target. Fix the known bugs. 128x128x128 sanity checked; remove prints

* some fixes

* fix update

* remove some unnecessary hacky; enable 256x256x256 tilesize

* update for function debug

* Add pipeline v3. Have some runtime issue and register spill

* Fix pipe v3 correctness issue

* remove unnecessary hacky

* clang format

* fix a bug

* fix the bug, functional test passed

* tempsave; buggy at passed 4 e8m0 to scaled mfma

* added fp4_bpreshuffle example, build failures

* fixed some bugs

* implement shuffled scale mxfp4gemm, blocker: opsel not effect

* hotfix

* fix bugs, build passed

* (M, N, K)=(128, 128, 128) function failed.

* temp save for gemm1. Function not ready

* fix compile error. Gemm2 pass. Gemm1 WIP

* fix bug for a lds read

* update moe

* Compile pass. Gemm1 function WIP

* update moe

* fix fp8; fix even/odd

* tempsave

* update moe

* Revert "update"

This reverts commit 960b2bce1c.

* Revert "use mem_op::set when topk=1"

This reverts commit def952a178.

* Add v3 128x128x128_4x4_16x16.co for gfx950

* temp cmake flag suppression  for aiter test

* add code for mxfp4 gemm, blockscale not supported yet

* gemm1 up-only pass. GU WIP

* function pass with inline asm hacky

* revert unexpected file change

* updated and build passed

* update CE elementOP

* added code for debug

* Gemm1 GUFusion function pass. Perf WIP

* Fix fp8/bf8; remove duplicated code

* disable the scheduler in v3; bring it back when compiler feature ready.

* update moe v1 pipeline

* Add gemm1 v1 32x128x128

* remove schedule barrier

* updated

* Fix fp8/bf8 B-row

* mfma using asm, device result correct, host result need to check

* gemm1 v3 64x128x128 debug

* fix cpu ref

* a/b thread_desc stride fix

* Use random scale for init1

* 16x16x128 input size blockscale function passed

* fix blockscale gemm bug

* tempsave. Almost all instances passed.

* v1 fix for mi350.

* temp save

* debug save

* update debug

* fix the bug, 128x128x256 tile function passed

* v3

* rename moe block selector and pipeline

* Add gemm1 v1

* Add gemm1 v1 to selector

* added mx moe block v3 support, function passed

* compile error fix

* Improve the pipeline

* Pack e8m0 as int32_t

* v1 compile pass. Function not ready

* debug synchronize issue over different GPU/ROCm

* minor fix

* Add profiler filter

* Add f4 ckProfiler

* Fix example compile error

* Add f4 profiler examples

* tempsave

* v1 function pass.

* v3 function pass

* align file and function name

* mx_moe_fp4 ready for aiter with clang-format.

* modify the way we represent fp4

* generalize the pipeline scheduling.

* init moe mx f4 scale shuffle

* Cmakelist diable compiler-bound flags

* mx_fp4 default parameter change

* Moe blockscale gemm1&gemm2 asm support for aiter. Suppression cmkae flag til new compler.

* update code

* tempsave; modify the way we represent fp4

* generalize the pipeline scheduling.

* Add gemm1 gfx942 .co support

* updated code, build passed.

* Update gemm2 asm with latest compiler flag

* Fix mx f4 ckProfiler

* Fix blockwise gemm mx v1

* lds conflict free + buffer load lds

* Add gemm2 v3 64x128x128

* fix a, b scale loading bugs, a, b scale loading now correctly

* Add gemm2 v3 64x128x128

* commit with debug info

* fix fp4 profiler

* Add mx fp4 pileline v1 instances

* Fix v2 topk_weight cal. Add silu asm.

* v2 tok_weight WIP

* init mx fp4 B no preshuffle version

* tempsave. compile pass, function wrong

* enable fp4 moe no weigth preshuffle, function pass

* update the TFlops calculation in the example

* Add gemm2 64x128x128 asm. Fix BF16 ref.

* fix 2 typos in fp4_preshuffle

* Better kernel selection in device classes

* correct preShuffleBuffer

we should used packed k to do shuffle.

* lds conflict free + buffer load lds

* optimize offset math in dma

* Fix fp4 ckProfiler

* Fix MX MFMA tests

* fix f4 pipeline issues

* gemm1 func pass

* update mx moe gemm1_bns tile size to 64x128x256

* update mx moe gemm1 gemm2 TF and BW calculation

* fix typo

* temp save

* Fix example_gemm_mx build

* rename the block pipeline

* correct a typo in tail

* Add rotating to mx examples

* fix the correctness issue

* Fix v1; use M padding

* Add NT flag to B/BScale buffer

* Merge gemm_mx_common.hpp

* temp save, 4.4~4.5

* Fix 'Merge gemm_mx_common.hpp'

* refactor the pipeline

* Pad the M for scale buffer unconditionaly

* update MX moe GEMM1 hotloopscheduling

* change the gemm1 tile from 64x128x128 to 128x64x128

* Unconditional Ascale padding

* Pad shuffled a scale only

* pad ascale

* add vmcnt guard for async copy

* Profiler add f4 wp

* Merge preshuffle device

* Add more fp4 wp instances

* Fix do_weight in gemm1. Fix cshuffle_datatype. Clang-format

* Clang-format after 2 merges

* Remove rocm6.3 workaround flags and macro

* Fix fp8 config

* Fix bf8 config

* flag and barrier fix for copmiler branch MainOpSelV3

* Add fp8 profiler instances

* Remove debug infos; Enable flags for blockscale f8

* No asm ver. for merging moe blocksale fp8 into mainline

* update the flag name for f8blockscale

* recover example

* fix performance bug of bpreshuffle f8 gemm

* clang format, remove  single rate mfma restriction for f8

* remove single rate mfma restriction for f8 blockscale gemm

* Fix moe blockscale gemm1 barrier 0x800 for new compiler

* add pipeline v1 for MOE Gemm2

* Use v1 pipeline for example_moe_gemm2_xdl_mx_fp4_bns

* Fix OOB; add MB96 instances

* remove unnecessary files

* fix the cmake issue

* Enable splitk for mxfp4; clang format;

* Generate random tensor values with multiple threads

* Use packed_size_v for A/BPackedSize

* Fix warning

* Fix target_compile_options for disabled target on gfx942

* fix moe pki4 on gfx950

* doc the kGroup definition

* Fix ThreadwiseTensorSliceTransfer_v4::Run (Fuse scale)

* Refactor thread_copy_lds_direct_load; fix gfx942 direct lds load example; fix f16_pki4 example

* Fix unknown compiler flag

* fix two failed examples.

* fix some failure tile size in gfx950 universal gemm. fix test_gemm_fp16

* workaround fix for test_gemm_f32; * We have very limited support for lds direct load if input matrix is not K major

* fix test_gemm_splitk;

* Fix compile for mx_mfma_op

* add mfma selection logic for multipled_v3

* Clean up

* Fix device gemm mx link error

* improve the global atomic pattern

* Revert unnecessary copyright updates

* restore minimum_occupancy logic

* Avoid data race in moe gemm2 ref

* Build fp8 gemm_multiply_multiply and moe only on gfx94/95

* update the instance in device_mx_gemm

* Resolve comments

* Copyright 2025

* Remove unused code

* fix library linking issue

---------

Co-authored-by: OscarXu <huaiguxu@amd.com>
Co-authored-by: lalala-sh <Jiaxing.Wen@amd.com>
Co-authored-by: mtgu0705 <mtgu@amd.com>
Co-authored-by: aska-0096 <haocwang@amd.com>
Co-authored-by: Your Name <you@example.com>
Co-authored-by: valarLip <340077269@qq.com>
Co-authored-by: feifei14119 <feiw@amd.com>
Co-authored-by: Lin, Qun <qlin@amd.com>
Co-authored-by: Andriy Roshchenko <andriy.roshchenko@amd.com>
Co-authored-by: joye <joye@amd.com>
Co-authored-by: asleepzzz <hanwen.chang@amd.com>
This commit is contained in:
Yi DING
2025-06-12 09:25:59 +08:00
committed by GitHub
parent 8c1ed6f4c1
commit 37554c31e8
85 changed files with 32508 additions and 431 deletions

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -153,9 +153,7 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
static constexpr bool is_single_rate_mfma =
(((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4) ||
(is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8) ||
((is_same<ComputeTypeA, f8_t>::value || is_same<ComputeTypeA, bf8_t>::value) &&
lcm_AK1_BK1 < 32))
(is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8))
? true
: false;
static constexpr auto is_scale_mfma = false;

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -168,9 +168,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
static constexpr bool is_single_rate_mfma =
(((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4) ||
(is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8) ||
((is_same<ComputeTypeA, f8_t>::value || is_same<ComputeTypeA, bf8_t>::value) &&
lcm_AK1_BK1 < 32))
(is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8))
? true
: false;
static constexpr auto is_scale_mfma = false;
@@ -1192,7 +1190,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
// N0, K0, Blocksize*KPack
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
@@ -1200,7 +1197,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
// B matrix in LDS memory, dst of blockwise copy
// dummy
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
// A matrix blockwise copy
@@ -1629,7 +1625,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
// N0, K0, Blocksize*KPack
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
@@ -1637,7 +1632,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
// B matrix in LDS memory, dst of blockwise copy
// dummy
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
// A matrix blockwise copy

View File

@@ -1221,7 +1221,6 @@ struct GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle
}
}
}
#if 0
// check gridwise gemm pipeline
const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
@@ -1232,7 +1231,6 @@ struct GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle
return false;
}
}
#endif
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return true;
}
@@ -2123,6 +2121,58 @@ struct GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle
n_thread_data_on_block_idx[I3]),
ck::tensor_operation::element_wise::PassThrough{}};
// calculate C grid descriptor
constexpr auto DWORD_BYTES = 4;
constexpr auto atomic_vector_size = DWORD_BYTES / sizeof(CDataType);
constexpr auto CShuffleBlockTransferClusterLengths = [&]() {
if constexpr(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::Set)
{
return CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock{};
}
// Atomic operation
else
{
return generate_sequence_v2(
[&](auto i) {
if constexpr(i == 3)
{
return Number<
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock{}
.At(i) *
CShuffleBlockTransferScalarPerVector_NPerBlock /
atomic_vector_size>{};
}
else if constexpr(i == 1)
{
return Number<
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock{}
.At(i) /
CShuffleBlockTransferScalarPerVector_NPerBlock *
atomic_vector_size>{};
}
else
{
return Number<
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock{}
.At(i)>{};
}
},
Number<4>{});
}
}();
constexpr auto CShuffleBlockTransferScalarPerVector = [&]() {
if constexpr(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::Set)
{
return CShuffleBlockTransferScalarPerVector_NPerBlock;
}
else
{
return atomic_vector_size;
}
}();
// shuffle: blockwise copy C from LDS to global
auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
ThisThreadBlock, // ThreadGroup
@@ -2132,15 +2182,15 @@ struct GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
decltype(CShuffleBlockTransferClusterLengths),
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
CShuffleDataType, // typename SrcData,
CDataType, // typename DstData,
decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CShuffleBlockTransferScalarPerVector, // index_t ScalarPerVector,
true, // bool ThreadTransferSrcResetCoordinateAfterRun,
false> // bool ThreadTransferDstResetCoordinateAfterRun>
{c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,

View File

@@ -183,27 +183,28 @@ struct GridwiseMoeGemm
static constexpr index_t NumDTensor = DsDataType::Size();
static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
static constexpr bool is_single_rate_mfma =
(((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4) ||
(is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8) ||
((is_same<ComputeTypeA, f8_t>::value || is_same<ComputeTypeA, bf8_t>::value) &&
lcm_AK1_BK1 < 32))
? true
: false;
static constexpr auto is_scale_mfma = false;
static constexpr auto mfma = MfmaSelector<ComputeTypeA,
MPerXdl,
NPerXdl,
ComputeTypeA,
is_single_rate_mfma,
is_scale_mfma>{};
static constexpr index_t KPack = math::max(lcm_AK1_BK1, mfma.selected_mfma.k_per_blk);
static constexpr index_t KLane = mfma.GetKPerXdlops() / mfma.GetK1PerXdlops();
static constexpr index_t KRepeat = KPerBlock / KLane / KPack;
static constexpr index_t NLane = NPerXdl;
static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave;
using mfma_selector = MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeB>;
static constexpr index_t KPack =
math::max(math::lcm(AK1Number, BK1Number), mfma_selector::selected_mfma.k_per_blk);
static constexpr index_t KLane =
mfma_selector::GetKPerXdlops() / mfma_selector::GetK1PerXdlops();
static constexpr index_t KGroup = []() {
if constexpr(is_same_v<remove_cvref_t<BDataType>, f8_t>)
// On gfx950, we have a mfma that required 32 f8 elements as input,
// splited into 2 groups of 16 f8 elements.
// the 2 groups is not contiguous in the B preshuffed layout.
// and we do not want it to be contiguous in the B preshuffled layout
// because a memory instruction can only read 16 f8 elements at a time.
return mfma_selector::selected_mfma.k_per_blk == 32 ? 2 : 1;
else
return 1;
}();
static constexpr index_t KRepeat = KPerBlock / KLane / (KPack / KGroup);
static constexpr index_t NLane = NPerXdl;
static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave;
// static constexpr index_t NumTokens = 1;
static constexpr index_t SortedTileSize = MPerBlock;
@@ -262,7 +263,7 @@ struct GridwiseMoeGemm
}
__host__ __device__ static auto CalculateBK0Shuffled(index_t K)
{
return math::integer_divide_ceil(K, KLane * KPack);
return math::integer_divide_ceil(K, KLane * KPack / KGroup);
}
__host__ __device__ static auto CalculateKPadded(index_t K)
@@ -404,7 +405,7 @@ struct GridwiseMoeGemm
__host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0)
{
constexpr index_t NkSwizzleNumber = Number<warpSize * KPack>{};
constexpr index_t NkSwizzleNumber = Number<warpSize * KPack / KGroup>{};
return make_naive_tensor_descriptor(
make_tuple(N0 / NWave, NWave, K0, NkSwizzleNumber),
make_tuple(NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, I1));
@@ -1314,7 +1315,7 @@ struct GridwiseMoeGemm
make_multi_index(n_block_data_idx_on_grid,
get_warp_local_1d_id() % NWave,
0,
KPack * (get_thread_local_1d_id() % warpSize)));
KPack / KGroup * (get_thread_local_1d_id() % warpSize)));
// LDS allocation for A and B: be careful of alignment
// Cast after lds
@@ -1360,7 +1361,7 @@ struct GridwiseMoeGemm
make_multi_index(n_block_data_idx_on_grid,
get_warp_local_1d_id() % NWave,
0,
KPack * (get_thread_local_1d_id() % warpSize)));
KPack / KGroup * (get_thread_local_1d_id() % warpSize)));
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
a_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1,
@@ -1899,7 +1900,8 @@ struct GridwiseMoeGemm
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n, problem.MBlock, problem.NBlock);
const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
// static_assert(NSwizzle == false, "to do fix: need another pr in sorting merged");
const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
if(expert_block_id * MPerBlock >= max_token_id)
return;
@@ -1908,12 +1910,13 @@ struct GridwiseMoeGemm
const auto block_mn = [&]() -> std::pair<int, int> {
if constexpr(NSwizzle)
{
const index_t ecnt_prefix = p_max_token_id[1 + expert_id];
const index_t prefix_block = ecnt_prefix * problem.NBlock;
const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix;
const index_t expert_swizzle = ecnt > 0 ? ecnt : 1;
const index_t bid_new = blockIdx.x - prefix_block;
const index_t nid = __builtin_amdgcn_readfirstlane(
const index_t ecnt_prefix = p_max_token_id[1 + expert_id];
const index_t prefix_block = ecnt_prefix * problem.NBlock;
const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix;
const index_t expert_swizzle =
ecnt > 0 ? ecnt : 1; // p_max_token_id[expert_id + 1]; // 2
const index_t bid_new = blockIdx.x - prefix_block;
const index_t nid = __builtin_amdgcn_readfirstlane(
bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
const index_t mid =
__builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
@@ -1924,9 +1927,9 @@ struct GridwiseMoeGemm
return {blockIdx.x, blockIdx.y};
}
}();
const index_t block_n_id = block_mn.first;
const index_t block_m_id = block_mn.second;
const index_t token0 =
__builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
@@ -1938,11 +1941,9 @@ struct GridwiseMoeGemm
constexpr auto AMRepeats = MPerBlock / AMThreads;
const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;
if(token_pos >= max_token_id || expert_block_id * MPerBlock >= max_token_id ||
token0 >= problem.NumTokens)
if(token_pos >= max_token_id || token0 >= problem.NumTokens)
return;
StaticallyIndexedArray<IndexType, AMRepeats>
gather_offsets; //= p_sorted_token_ids[token_pos];
StaticallyIndexedArray<IndexType, AMRepeats> gather_offsets;
static_for<0, AMRepeats, 1>{}([&](auto m0) {
const index_t fused_token = p_sorted_token_ids[token_pos + m0];
index_t token_offset = fused_token & 0xffffff;
@@ -1952,7 +1953,8 @@ struct GridwiseMoeGemm
}
gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
});
const index_t expert_stride = __builtin_amdgcn_readfirstlane(problem.N * problem.K);
const index_t expert_stride =
__builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
// N0, K0, Blocksize*KPack
const index_t n_block_data_idx_on_grid =
@@ -2025,7 +2027,7 @@ struct GridwiseMoeGemm
make_multi_index(n_block_data_idx_on_grid,
get_warp_local_1d_id() % NWave,
0,
KPack * (get_thread_local_1d_id() % warpSize)));
KPack / KGroup * (get_thread_local_1d_id() % warpSize)));
// LDS allocation for A and B: be careful of alignment
// Cast after lds
@@ -2042,24 +2044,76 @@ struct GridwiseMoeGemm
static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
decltype(c_thread_buf) c_thread_buf_up;
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
float,
c_thread_buf.num_of_v_,
c_thread_buf.s_per_v,
true>
c_thread_buf_fp32;
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
KPerBlock);
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1,
a_blockwise_copy,
a_grid_buf,
a_block_bufs,
a_block_slice_copy_step,
b_grid_desc_bpreshuffled,
b_blockwise_copy,
b_grid_buf,
b_block_bufs,
b_block_slice_copy_step,
c_thread_buf,
num_k_block_main_loop);
if constexpr(IsInputGemm)
{
const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize;
const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid_up + expert_id * expert_stride / BPackedSize,
b_grid_desc_bpreshuffled.GetElementSpaceSize());
auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2<
BDataType,
BDataType,
decltype(b_grid_desc_bpreshuffled),
decltype(b_block_desc_bk0_n_bk1),
Sequence<Number<NXdlPerWave>{}, I1, Number<KRepeat>{}, Number<BK1Value>{}>,
Sequence<1, 2, 0, 3>,
3,
BBlockTransferSrcScalarPerVector,
BThreadTransferSrcResetCoordinateAfterRun,
true>(b_grid_desc_bpreshuffled,
make_multi_index(n_block_data_idx_on_grid,
get_warp_local_1d_id() % NWave,
0,
KPack / KGroup * (get_thread_local_1d_id() % warpSize)));
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
a_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1,
a_blockwise_copy,
a_grid_buf,
a_block_bufs,
a_block_slice_copy_step,
b_grid_desc_bpreshuffled,
b_blockwise_copy,
b_blockwise_copy_up,
b_grid_buf,
b_grid_buf_up,
b_block_bufs,
b_block_slice_copy_step,
c_thread_buf,
c_thread_buf_up,
num_k_block_main_loop);
}
else
{
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
a_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1,
a_blockwise_copy,
a_grid_buf,
a_block_bufs,
a_block_slice_copy_step,
b_grid_desc_bpreshuffled,
b_blockwise_copy,
b_grid_buf,
b_block_bufs,
b_block_slice_copy_step,
c_thread_buf,
num_k_block_main_loop);
}
// shuffle C and write out
{
@@ -2087,6 +2141,185 @@ struct GridwiseMoeGemm
constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
// mul scales
const float* p_sorted_weights_0 = p_ds_grid[I0];
const float* p_scale_b = p_ds_grid[I1];
static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock);
static_assert(M4 == 4);
const index_t m1 = get_warp_local_1d_id() / NWave;
const index_t m3 = threadIdx.x % get_warp_size() / MPerXdl;
if(p_sorted_weights_0 != nullptr && p_scale_b != nullptr)
{
if constexpr(PerTokenQuant)
{
constexpr index_t scale_stride = (IsInputGemm ? 2 : 1);
p_scale_b += expert_id * problem.N * scale_stride + block_n_id * NPerBlock +
get_warp_local_1d_id() % NWave * NPerXdl + threadIdx.x % NPerXdl;
}
else
{
p_scale_b += expert_id;
}
vector_type<int32_t, 4> scale_token_ids;
vector_type<float, 4> topk_weights;
static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
const float scale_b = p_scale_b[n0 * NWave * NPerXdl * PerTokenQuant];
static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave
static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk
const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 +
m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4;
if constexpr(PerTokenQuant)
{
scale_token_ids =
*c_style_pointer_cast<const vector_type<int32_t, M4>*>(
p_sorted_token_ids + m_pos);
}
if constexpr(MulRoutedWeight)
{
topk_weights = *c_style_pointer_cast<const vector_type<float, M4>*>(
p_ds_grid[I2] + m_pos);
}
static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size
float scale_a = [&]() {
if constexpr(PerTokenQuant)
{
index_t fused_token = scale_token_ids.AsType<index_t>()[m4];
const index_t token_offset = fused_token & 0xffffff;
return token_offset < problem.NumTokens
? p_sorted_weights_0[token_offset]
: 0.0;
}
else
{
return p_sorted_weights_0[0];
}
}();
constexpr index_t c_offset =
blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
make_tuple(m0, n0, m2 * M4 + m4));
constexpr auto cidx = Number<c_offset>{};
if constexpr(IsInputGemm) // gu fusion
{
if constexpr(ActivationOperation == Activation::silu_and_mul)
{
const float scale_up =
p_scale_b[(n0 * NWave * NPerXdl + problem.N) *
PerTokenQuant];
float gate = scale_a * scale_b * c_thread_buf[cidx];
float up = scale_a * scale_up * c_thread_buf_up[cidx];
if constexpr(MulRoutedWeight)
{
gate = gate * topk_weights.AsType<float>()[m4];
up = up * topk_weights.AsType<float>()[m4];
}
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
{
gate *= 16;
up *= 16;
}
tensor_operation::element_wise::Silu{}(gate, gate);
c_thread_buf_fp32(cidx) = gate * up;
}
else if(ActivationOperation == Activation::gelu_and_mul)
{
const float scale_up =
p_scale_b[(n0 * NWave * NPerXdl + problem.N) *
PerTokenQuant];
float gate = scale_a * scale_b * c_thread_buf[cidx];
float up = scale_a * scale_up * c_thread_buf_up[cidx];
if constexpr(MulRoutedWeight)
{
gate = gate * topk_weights.AsType<float>()[m4];
up = up * topk_weights.AsType<float>()[m4];
}
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
{
gate *= 16;
up *= 16;
}
tensor_operation::element_wise::Gelu{}(gate, gate);
c_thread_buf_fp32(cidx) = gate * up;
}
}
else
{
c_thread_buf_fp32(cidx) =
scale_a * scale_b * c_thread_buf[cidx];
if constexpr(MulRoutedWeight)
{
c_thread_buf_fp32(cidx) = c_thread_buf_fp32(cidx) *
topk_weights.AsType<float>()[m4];
}
}
});
});
});
});
}
else
{
vector_type<float, 4> topk_weights; // for gemm2 only
static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave
static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk
const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 +
m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4;
if constexpr(MulRoutedWeight)
{
topk_weights = *c_style_pointer_cast<const vector_type<float, M4>*>(
p_ds_grid[I2] + m_pos);
}
static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size
constexpr index_t c_offset =
blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
make_tuple(m0, n0, m2 * M4 + m4));
constexpr auto cidx = Number<c_offset>{};
if constexpr(IsInputGemm) // gu fusion
{
if constexpr(ActivationOperation == Activation::silu_and_mul)
{
float gate = c_thread_buf[cidx];
float up = c_thread_buf_up[cidx];
if constexpr(MulRoutedWeight)
{
gate = gate * topk_weights.AsType<float>()[m4];
up = up * topk_weights.AsType<float>()[m4];
}
tensor_operation::element_wise::Silu{}(gate, gate);
c_thread_buf_fp32(cidx) = gate * up;
}
else if(ActivationOperation == Activation::gelu_and_mul)
{
float gate = c_thread_buf[cidx];
float up = c_thread_buf_up[cidx];
if constexpr(MulRoutedWeight)
{
gate = gate * topk_weights.AsType<float>()[m4];
up = up * topk_weights.AsType<float>()[m4];
}
tensor_operation::element_wise::Gelu{}(gate, gate);
c_thread_buf_fp32(cidx) = gate * up;
}
}
else
{
c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
if constexpr(MulRoutedWeight)
{
c_thread_buf_fp32(cidx) = topk_weights.AsType<float>()[m4] *
c_thread_buf_fp32[cidx];
}
}
});
});
});
});
}
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
@@ -2184,18 +2417,8 @@ struct GridwiseMoeGemm
const auto ds_grid_buf = generate_tuple(
[&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
const DDataType* ptr_ = p_ds_grid[i];
// hack logic here to support different kind of strides. todo fix it.
// ascale t, 1; bscale E, N, 1, move ptr to E
// if(i.value == 1)
// {
// ptr_ +=
// expert_id * (problem.StrideDs[1] ? problem.StrideDs[1] * problem.N :
// 1);
// }
return make_dynamic_buffer<AddressSpaceEnum::Global>(
ptr_, ds_grid_desc_m_n[i].GetElementSpaceSize());
p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
},
Number<NumDTensor>{});
@@ -2271,7 +2494,6 @@ struct GridwiseMoeGemm
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
// space filling curve for threadwise C in VGPR
constexpr auto sfc_c_vgpr =
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
@@ -2310,7 +2532,7 @@ struct GridwiseMoeGemm
block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
static_for<0, EMRepeats, 1>{}([&](auto m0) {
const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
index_t token_offset = fused_token & 0xffffff;
IndexType token_offset = fused_token & 0xffffff;
if constexpr(IsInputGemm)
{
token_offset = token_offset * problem.TopK + (fused_token >> 24);
@@ -2323,7 +2545,7 @@ struct GridwiseMoeGemm
// each thread write its data from VGPR to LDS
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
c_thread_buf,
c_thread_buf_fp32,
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
c_shuffle_block_buf);

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff