Mirror of https://github.com/ROCm/composable_kernel.git (synced 2026-05-03 21:21:22 +00:00)
Add MoE & FP8 Blockscale WP Kernels for GFX950 (#2297)
* [fix] align v3 gufusion pipeline
* fix device kernel selection
* Add .co direct asm support by CK_USE_ASM_MOE_STAGE2_BLOCKSCALE
* experimental optimization for scale load in blkscale gemm
* Add asm for no-loop v3_128x128x128
* fix bugs
* tune fp8 example
* Update v1_128x128x128 to 2x2 instead of 4x1
* wip
* add warmup to asm launch
* wip2
* 16x16 function merged to moe
* temp save, a performant version
* wip3
* Update .co binary to 16x16
* 16x16x128 correct; 64x64x128 failed
* update
* use mem_op::set when topk=1
* add mx fp8 b_preshuffle support, function not yet tested
* Split the fp4 target. Fix the known bugs. 128x128x128 sanity checked; remove prints
* some fixes
* fix update
* remove some unnecessary hacks; enable 256x256x256 tile size
* update for function debug
* Add pipeline v3. Has some runtime issues and register spills
* Fix pipe v3 correctness issue
* remove unnecessary hacks
* clang format
* fix a bug
* fix the bug, functional test passed
* temp save; buggy when passing 4 e8m0 to scaled mfma
* added fp4_bpreshuffle example, build failures
* fixed some bugs
* implement shuffled scale mxfp4 gemm, blocker: opsel has no effect
* hotfix
* fix bugs, build passed
* (M, N, K) = (128, 128, 128) function failed
* temp save for gemm1. Function not ready
* fix compile error. Gemm2 passes. Gemm1 WIP
* fix bug for an LDS read
* update moe
* Compile passes. Gemm1 function WIP
* update moe
* fix fp8; fix even/odd
* temp save
* update moe
* Revert "update". This reverts commit 960b2bce1c.
* Revert "use mem_op::set when topk=1". This reverts commit def952a178.
* Add v3 128x128x128_4x4_16x16.co for gfx950
* temp cmake flag suppression for aiter test
* add code for mxfp4 gemm, blockscale not supported yet
* gemm1 up-only passes. GU WIP
* function passes with inline asm hack
* revert unexpected file change
* updated and build passed
* update CE elementOp
* added code for debug
* Gemm1 GUFusion function passes. Perf WIP
* Fix fp8/bf8; remove duplicated code
* disable the scheduler in v3; bring it back when the compiler feature is ready
* update moe v1 pipeline
* Add gemm1 v1 32x128x128
* remove schedule barrier
* updated
* Fix fp8/bf8 B-row
* mfma using asm, device result correct, host result needs checking
* gemm1 v3 64x128x128 debug
* fix cpu ref
* a/b thread_desc stride fix
* Use random scale for init1
* 16x16x128 input size blockscale function passed
* fix blockscale gemm bug
* temp save. Almost all instances passed
* v1 fix for mi350
* temp save
* debug save
* update debug
* fix the bug, 128x128x256 tile function passed
* v3
* rename moe block selector and pipeline
* Add gemm1 v1
* Add gemm1 v1 to selector
* added mx moe block v3 support, function passed
* compile error fix
* Improve the pipeline
* Pack e8m0 as int32_t
* v1 compiles. Function not ready
* debug synchronize issue over different GPU/ROCm
* minor fix
* Add profiler filter
* Add f4 ckProfiler
* Fix example compile error
* Add f4 profiler examples
* temp save
* v1 function passes
* v3 function passes
* align file and function names
* mx_moe_fp4 ready for aiter with clang-format
* modify the way we represent fp4
* generalize the pipeline scheduling
* init moe mx f4 scale shuffle
* CMakeLists: disable compiler-bound flags
* mx_fp4 default parameter change
* Moe blockscale gemm1 & gemm2 asm support for aiter. Suppress cmake flag until new compiler
* update code
* temp save; modify the way we represent fp4
* generalize the pipeline scheduling
* Add gemm1 gfx942 .co support
* updated code, build passed
* Update gemm2 asm with latest compiler flag
* Fix mx f4 ckProfiler
* Fix blockwise gemm mx v1
* lds conflict free + buffer load lds
* Add gemm2 v3 64x128x128
* fix a, b scale loading bugs; a and b scales now load correctly
* Add gemm2 v3 64x128x128
* commit with debug info
* fix fp4 profiler
* Add mx fp4 pipeline v1 instances
* Fix v2 topk_weight calculation. Add silu asm.
* v2 topk_weight WIP
* init mx fp4 B no-preshuffle version
* temp save. compile passes, function wrong
* enable fp4 moe without weight preshuffle, function passes
* update the TFlops calculation in the example
* Add gemm2 64x128x128 asm. Fix BF16 ref.
* fix 2 typos in fp4_preshuffle
* Better kernel selection in device classes
* correct preShuffleBuffer: we should use packed K to do the shuffle
* lds conflict free + buffer load lds
* optimize offset math in dma
* Fix fp4 ckProfiler
* Fix MX MFMA tests
* fix f4 pipeline issues
* gemm1 func passes
* update mx moe gemm1_bns tile size to 64x128x256
* update mx moe gemm1/gemm2 TF and BW calculation
* fix typo
* temp save
* Fix example_gemm_mx build
* rename the block pipeline
* correct a typo in tail
* Add rotating to mx examples
* fix the correctness issue
* Fix v1; use M padding
* Add NT flag to B/BScale buffer
* Merge gemm_mx_common.hpp
* temp save, 4.4~4.5
* Fix 'Merge gemm_mx_common.hpp'
* refactor the pipeline
* Pad the M for scale buffer unconditionally
* update MX moe GEMM1 hot-loop scheduling
* change the gemm1 tile from 64x128x128 to 128x64x128
* Unconditional Ascale padding
* Pad shuffled a scale only
* pad ascale
* add vmcnt guard for async copy
* Profiler: add f4 wp
* Merge preshuffle device
* Add more fp4 wp instances
* Fix do_weight in gemm1. Fix cshuffle_datatype. Clang-format
* Clang-format after 2 merges
* Remove rocm6.3 workaround flags and macro
* Fix fp8 config
* Fix bf8 config
* flag and barrier fix for compiler branch MainOpSelV3
* Add fp8 profiler instances
* Remove debug info; enable flags for blockscale f8
* No asm version for merging moe blockscale fp8 into mainline
* update the flag name for f8blockscale
* recover example
* fix performance bug of bpreshuffle f8 gemm
* clang format, remove single-rate mfma restriction for f8
* remove single-rate mfma restriction for f8 blockscale gemm
* Fix moe blockscale gemm1 barrier 0x800 for new compiler
* add pipeline v1 for MOE Gemm2
* Use v1 pipeline for example_moe_gemm2_xdl_mx_fp4_bns
* Fix OOB; add MB96 instances
* remove unnecessary files
* fix the cmake issue
* Enable splitk for mxfp4; clang format
* Generate random tensor values with multiple threads
* Use packed_size_v for A/BPackedSize
* Fix warning
* Fix target_compile_options for disabled target on gfx942
* fix moe pki4 on gfx950
* document the kGroup definition
* Fix ThreadwiseTensorSliceTransfer_v4::Run (fuse scale)
* Refactor thread_copy_lds_direct_load; fix gfx942 direct lds load example; fix f16_pki4 example
* Fix unknown compiler flag
* fix two failing examples
* fix some failing tile sizes in gfx950 universal gemm
* fix test_gemm_fp16
* workaround fix for test_gemm_f32
* We have very limited support for lds direct load if the input matrix is not K-major
* fix test_gemm_splitk
* Fix compile for mx_mfma_op
* add mfma selection logic for multipled_v3
* Clean up
* Fix device gemm mx link error
* improve the global atomic pattern
* Revert unnecessary copyright updates
* restore minimum_occupancy logic
* Avoid data race in moe gemm2 ref
* Build fp8 gemm_multiply_multiply and moe only on gfx94/95
* update the instance in device_mx_gemm
* Resolve comments
* Copyright 2025
* Remove unused code
* fix library linking issue

---------

Co-authored-by: OscarXu <huaiguxu@amd.com>
Co-authored-by: lalala-sh <Jiaxing.Wen@amd.com>
Co-authored-by: mtgu0705 <mtgu@amd.com>
Co-authored-by: aska-0096 <haocwang@amd.com>
Co-authored-by: Your Name <you@example.com>
Co-authored-by: valarLip <340077269@qq.com>
Co-authored-by: feifei14119 <feiw@amd.com>
Co-authored-by: Lin, Qun <qlin@amd.com>
Co-authored-by: Andriy Roshchenko <andriy.roshchenko@amd.com>
Co-authored-by: joye <joye@amd.com>
Co-authored-by: asleepzzz <hanwen.chang@amd.com>
@@ -1,5 +1,5 @@
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
 
 #pragma once
 
@@ -153,9 +153,7 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
     static constexpr bool is_single_rate_mfma =
         (((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
           lcm_AK1_BK1 <= 4) ||
-         (is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8) ||
-         ((is_same<ComputeTypeA, f8_t>::value || is_same<ComputeTypeA, bf8_t>::value) &&
-          lcm_AK1_BK1 < 32))
+         (is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8))
             ? true
             : false;
     static constexpr auto is_scale_mfma = false;
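The hunk above (and the identical one in the b_preshuffle gridwise gemm below) drops f8_t/bf8_t from the single-rate condition, so FP8 block-scale GEMMs stay eligible for the faster double-rate MFMA. A minimal standalone sketch of the resulting predicate, not part of this commit; the tag structs are stand-ins for CK's half_t/bhalf_t/int8_t:

#include <type_traits>

struct half_t {};
struct bhalf_t {};
struct int8_tag {};

template <typename ComputeTypeA>
constexpr bool single_rate_mfma(int lcm_ak1_bk1)
{
    // f16/bf16 only need single-rate MFMA when the K vector is short;
    // int8 when it is at most 8. f8/bf8 no longer force single rate here.
    return ((std::is_same_v<ComputeTypeA, half_t> ||
             std::is_same_v<ComputeTypeA, bhalf_t>) &&
            lcm_ak1_bk1 <= 4) ||
           (std::is_same_v<ComputeTypeA, int8_tag> && lcm_ak1_bk1 <= 8);
}

static_assert(!single_rate_mfma<half_t>(8));  // long K vector keeps double rate
static_assert(single_rate_mfma<int8_tag>(8)); // int8 still forces single rate

int main() { return 0; }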
@@ -1,5 +1,5 @@
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
 
 #pragma once
 
@@ -168,9 +168,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
     static constexpr bool is_single_rate_mfma =
         (((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
           lcm_AK1_BK1 <= 4) ||
-         (is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8) ||
-         ((is_same<ComputeTypeA, f8_t>::value || is_same<ComputeTypeA, bf8_t>::value) &&
-          lcm_AK1_BK1 < 32))
+         (is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8))
             ? true
             : false;
     static constexpr auto is_scale_mfma = false;
@@ -1192,7 +1190,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
         const index_t m_block_data_idx_on_grid =
             __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
 
-        // N0, K0, Blocksize*KPack
         const index_t n_block_data_idx_on_grid =
             __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
 
@@ -1200,7 +1197,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
         constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
 
         // B matrix in LDS memory, dst of blockwise copy
-        // dummy
        constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
 
         // A matrix blockwise copy
@@ -1629,7 +1625,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
         const index_t m_block_data_idx_on_grid =
             __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
 
-        // N0, K0, Blocksize*KPack
         const index_t n_block_data_idx_on_grid =
             __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
 
@@ -1637,7 +1632,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
         constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
 
         // B matrix in LDS memory, dst of blockwise copy
-        // dummy
         constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
 
         // A matrix blockwise copy
File diff suppressed because it is too large
@@ -1221,7 +1221,6 @@ struct GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle
                 }
             }
         }
-
 #if 0
         // check gridwise gemm pipeline
         const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
@@ -1232,7 +1231,6 @@ struct GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle
                 return false;
             }
         }
-
 #endif
         // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
         return true;
     }
@@ -2123,6 +2121,58 @@ struct GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle
                                                   n_thread_data_on_block_idx[I3]),
                 ck::tensor_operation::element_wise::PassThrough{}};
 
+            // calculate C grid descriptor
+            constexpr auto DWORD_BYTES        = 4;
+            constexpr auto atomic_vector_size = DWORD_BYTES / sizeof(CDataType);
+
+            constexpr auto CShuffleBlockTransferClusterLengths = [&]() {
+                if constexpr(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::Set)
+                {
+                    return CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock{};
+                }
+                // Atomic operation
+                else
+                {
+                    return generate_sequence_v2(
+                        [&](auto i) {
+                            if constexpr(i == 3)
+                            {
+                                return Number<
+                                    CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock{}
+                                        .At(i) *
+                                    CShuffleBlockTransferScalarPerVector_NPerBlock /
+                                    atomic_vector_size>{};
+                            }
+                            else if constexpr(i == 1)
+                            {
+                                return Number<
+                                    CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock{}
+                                        .At(i) /
+                                    CShuffleBlockTransferScalarPerVector_NPerBlock *
+                                    atomic_vector_size>{};
+                            }
+                            else
+                            {
+                                return Number<
+                                    CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock{}
+                                        .At(i)>{};
+                            }
+                        },
+                        Number<4>{});
+                }
+            }();
+
+            constexpr auto CShuffleBlockTransferScalarPerVector = [&]() {
+                if constexpr(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::Set)
+                {
+                    return CShuffleBlockTransferScalarPerVector_NPerBlock;
+                }
+                else
+                {
+                    return atomic_vector_size;
+                }
+            }();
+
             // shuffle: blockwise copy C from LDS to global
             auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
                 ThisThreadBlock, // ThreadGroup
@@ -2132,15 +2182,15 @@ struct GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle
                          CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                          1,
                          CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
-                CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
+                decltype(CShuffleBlockTransferClusterLengths),
                 Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
                 CShuffleDataType,     // typename SrcData,
                 CDataType,            // typename DstData,
                 decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
                 decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
-                Sequence<0, 1, 2, 3>,                           // typename DimAccessOrder,
-                3,                                              // index_t VectorDim,
-                CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
+                Sequence<0, 1, 2, 3>,                 // typename DimAccessOrder,
+                3,                                    // index_t VectorDim,
+                CShuffleBlockTransferScalarPerVector, // index_t ScalarPerVector,
                 true,  // bool ThreadTransferSrcResetCoordinateAfterRun,
                 false> // bool ThreadTransferDstResetCoordinateAfterRun>
                 {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
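Why the added lambda rescales the cluster lengths: a global atomic here commits one DWORD (4 bytes) per operation, so a thread can only write 4 / sizeof(CDataType) elements at a time; the N cluster is widened and the M cluster narrowed by the same factor so the shuffle's total thread count is unchanged. A hedged host-side sketch of that arithmetic, with an assumed 128-thread cluster and fp16 output (the numbers are illustrative, not taken from the kernel):

#include <cstdio>

int main()
{
    constexpr int DWORD_BYTES       = 4;
    constexpr int sizeof_cdata      = 2;                          // assumed fp16 output
    constexpr int atomic_vec        = DWORD_BYTES / sizeof_cdata; // 2 elements per atomic
    constexpr int scalar_per_vector = 8;                          // assumed shuffle vector width

    // Assumed cluster [MBlock, MPerBlock, NBlock, NPerBlock] = [1, 32, 1, 4]:
    constexpr int n_cluster = 4 * scalar_per_vector / atomic_vec;  // widened to 16
    constexpr int m_cluster = 32 * atomic_vec / scalar_per_vector; // narrowed to 8
    static_assert(m_cluster * n_cluster == 32 * 4, "thread count is preserved");

    std::printf("atomic vector %d, clusters M=%d N=%d\n", atomic_vec, m_cluster, n_cluster);
    return 0;
}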
@@ -183,27 +183,28 @@ struct GridwiseMoeGemm
 
     static constexpr index_t NumDTensor = DsDataType::Size();
 
-    static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
-    static constexpr bool is_single_rate_mfma =
-        (((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
-          lcm_AK1_BK1 <= 4) ||
-         (is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8) ||
-         ((is_same<ComputeTypeA, f8_t>::value || is_same<ComputeTypeA, bf8_t>::value) &&
-          lcm_AK1_BK1 < 32))
-            ? true
-            : false;
-    static constexpr auto is_scale_mfma = false;
-    static constexpr auto mfma = MfmaSelector<ComputeTypeA,
-                                              MPerXdl,
-                                              NPerXdl,
-                                              ComputeTypeA,
-                                              is_single_rate_mfma,
-                                              is_scale_mfma>{};
-    static constexpr index_t KPack   = math::max(lcm_AK1_BK1, mfma.selected_mfma.k_per_blk);
-    static constexpr index_t KLane   = mfma.GetKPerXdlops() / mfma.GetK1PerXdlops();
-    static constexpr index_t KRepeat = KPerBlock / KLane / KPack;
-    static constexpr index_t NLane   = NPerXdl;
-    static constexpr index_t NWave   = NPerBlock / NPerXdl / NXdlPerWave;
+    using mfma_selector = MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeB>;
+    static constexpr index_t KPack =
+        math::max(math::lcm(AK1Number, BK1Number), mfma_selector::selected_mfma.k_per_blk);
+    static constexpr index_t KLane =
+        mfma_selector::GetKPerXdlops() / mfma_selector::GetK1PerXdlops();
+
+    static constexpr index_t KGroup = []() {
+        if constexpr(is_same_v<remove_cvref_t<BDataType>, f8_t>)
+            // On gfx950 there is an mfma that requires 32 f8 elements as input,
+            // split into 2 groups of 16 f8 elements.
+            // The 2 groups are not contiguous in the preshuffled B layout,
+            // and we do not want them to be contiguous in the preshuffled B layout,
+            // because a memory instruction can only read 16 f8 elements at a time.
+            return mfma_selector::selected_mfma.k_per_blk == 32 ? 2 : 1;
+        else
+            return 1;
+    }();
+
+    static constexpr index_t KRepeat = KPerBlock / KLane / (KPack / KGroup);
+
+    static constexpr index_t NLane = NPerXdl;
+    static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave;
     // static constexpr index_t NumTokens = 1;
     static constexpr index_t SortedTileSize = MPerBlock;
 
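A small sketch of the KGroup bookkeeping documented above, using assumed tile numbers (not from the commit): with the gfx950 f8 MFMA consuming k_per_blk = 32 elements per lane but a single buffer load fetching only 16 f8 values, the 32 elements are kept as KGroup = 2 non-contiguous groups of 16, and each repeat of the hot loop advances by one group.

constexpr int KPerBlock = 128; // assumed block tile along K
constexpr int KPack     = 32;  // k_per_blk of the assumed gfx950 f8 MFMA
constexpr int KLane     = 2;   // assumed K lanes per XDL op
constexpr int KGroup    = (KPack == 32) ? 2 : 1; // two 16-element loads per MFMA

// Each repeat advances one 16-element group rather than a full KPack:
constexpr int KRepeat = KPerBlock / KLane / (KPack / KGroup);
static_assert(KRepeat == 4, "128 / 2 / (32 / 2)");

int main() { return 0; }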
@@ -262,7 +263,7 @@ struct GridwiseMoeGemm
     }
     __host__ __device__ static auto CalculateBK0Shuffled(index_t K)
     {
-        return math::integer_divide_ceil(K, KLane * KPack);
+        return math::integer_divide_ceil(K, KLane * KPack / KGroup);
     }
 
     __host__ __device__ static auto CalculateKPadded(index_t K)
@@ -404,7 +405,7 @@ struct GridwiseMoeGemm
 
     __host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0)
     {
-        constexpr index_t NkSwizzleNumber = Number<warpSize * KPack>{};
+        constexpr index_t NkSwizzleNumber = Number<warpSize * KPack / KGroup>{};
         return make_naive_tensor_descriptor(
             make_tuple(N0 / NWave, NWave, K0, NkSwizzleNumber),
             make_tuple(NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, I1));
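For orientation, a hedged sketch of the preshuffled-B addressing this descriptor implies, with assumed sizes (none of these constants come from the commit): the innermost block is one wavefront's worth of K data, warpSize * KPack / KGroup elements laid out contiguously.

#include <cstdio>

int main()
{
    constexpr long long warp_size = 64, KPack = 32, KGroup = 2;
    constexpr long long Nk = warp_size * KPack / KGroup; // 1024 contiguous elements
    constexpr long long NWave = 2, K0 = 8;               // assumed wave / K0 counts

    // offset of (n0, nw, k0, i) under strides (NWave*K0*Nk, K0*Nk, Nk, 1)
    auto offset = [&](long long n0, long long nw, long long k0, long long i) {
        return n0 * NWave * K0 * Nk + nw * K0 * Nk + k0 * Nk + i;
    };
    std::printf("offset = %lld\n", offset(1, 0, 2, 5)); // 16384 + 0 + 2048 + 5
    return 0;
}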
@@ -1314,7 +1315,7 @@ struct GridwiseMoeGemm
                           make_multi_index(n_block_data_idx_on_grid,
                                            get_warp_local_1d_id() % NWave,
                                            0,
-                                           KPack * (get_thread_local_1d_id() % warpSize)));
+                                           KPack / KGroup * (get_thread_local_1d_id() % warpSize)));
 
         // LDS allocation for A and B: be careful of alignment
         // Cast after lds
@@ -1360,7 +1361,7 @@ struct GridwiseMoeGemm
                           make_multi_index(n_block_data_idx_on_grid,
                                            get_warp_local_1d_id() % NWave,
                                            0,
-                                           KPack * (get_thread_local_1d_id() % warpSize)));
+                                           KPack / KGroup * (get_thread_local_1d_id() % warpSize)));
         blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
             a_grid_desc_ak0_m_ak1,
             a_block_desc_ak0_m_ak1,
@@ -1899,7 +1900,8 @@ struct GridwiseMoeGemm
         const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
             MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                 c_grid_desc_m_n, problem.MBlock, problem.NBlock);
         const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
+        // static_assert(NSwizzle == false, "to do fix: need another pr in sorting merged");
         const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
         if(expert_block_id * MPerBlock >= max_token_id)
             return;
@@ -1908,12 +1910,13 @@ struct GridwiseMoeGemm
         const auto block_mn = [&]() -> std::pair<int, int> {
             if constexpr(NSwizzle)
             {
-                const index_t ecnt_prefix   = p_max_token_id[1 + expert_id];
-                const index_t prefix_block  = ecnt_prefix * problem.NBlock;
-                const index_t ecnt          = p_max_token_id[2 + expert_id] - ecnt_prefix;
-                const index_t expert_swizzle = ecnt > 0 ? ecnt : 1;
-                const index_t bid_new = blockIdx.x - prefix_block;
-                const index_t nid = __builtin_amdgcn_readfirstlane(
+                const index_t ecnt_prefix  = p_max_token_id[1 + expert_id];
+                const index_t prefix_block = ecnt_prefix * problem.NBlock;
+                const index_t ecnt         = p_max_token_id[2 + expert_id] - ecnt_prefix;
+                const index_t expert_swizzle =
+                    ecnt > 0 ? ecnt : 1; // p_max_token_id[expert_id + 1]; // 2
+                const index_t bid_new = blockIdx.x - prefix_block;
+                const index_t nid = __builtin_amdgcn_readfirstlane(
                     bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
                 const index_t mid =
                     __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
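A worked example of the NSwizzle block-id mapping above, with small assumed numbers (not from the commit): blocks belonging to one expert are renumbered so that groups of 8 consecutive N blocks are walked before advancing M, which improves B-tile reuse in cache.

#include <cstdio>

int main()
{
    const int ecnt_prefix    = 4; // assumed M blocks before this expert
    const int expert_swizzle = 3; // assumed M blocks owned by this expert
    for(int bid_new = 0; bid_new < 24; ++bid_new) // 3 M blocks * 8 N blocks
    {
        const int nid = bid_new % 8 + bid_new / (8 * expert_swizzle) * 8;
        const int mid = ecnt_prefix + bid_new / 8 % expert_swizzle;
        std::printf("bid %2d -> (m %d, n %d)\n", bid_new, mid, nid);
    }
    return 0;
}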
@@ -1924,9 +1927,9 @@ struct GridwiseMoeGemm
                 return {blockIdx.x, blockIdx.y};
             }
         }();
 
         const index_t block_n_id = block_mn.first;
         const index_t block_m_id = block_mn.second;
 
         const index_t token0 =
             __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
@@ -1938,11 +1941,9 @@ struct GridwiseMoeGemm
         constexpr auto AMRepeats = MPerBlock / AMThreads;
         const index_t token_pos  = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;
 
-        if(token_pos >= max_token_id || expert_block_id * MPerBlock >= max_token_id ||
-           token0 >= problem.NumTokens)
+        if(token_pos >= max_token_id || token0 >= problem.NumTokens)
             return;
-        StaticallyIndexedArray<IndexType, AMRepeats>
-            gather_offsets; //= p_sorted_token_ids[token_pos];
+        StaticallyIndexedArray<IndexType, AMRepeats> gather_offsets;
         static_for<0, AMRepeats, 1>{}([&](auto m0) {
             const index_t fused_token = p_sorted_token_ids[token_pos + m0];
             index_t token_offset      = fused_token & 0xffffff;
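The gather code above assumes a packed sorted-token-id format; a hedged sketch of that encoding (field widths as implied by the masks in the diff, all other values illustrative): the low 24 bits carry the token index, the high 8 bits the top-k slot, so the A gather offset is token * K, and the input GEMM's scatter side uses token * TopK + slot.

#include <cstdint>
#include <cstdio>

int main()
{
    const uint32_t fused_token = (2u << 24) | 1234u; // top-k slot 2, token 1234
    const int32_t  token       = fused_token & 0xffffff;
    const int32_t  slot        = fused_token >> 24;
    const int64_t  K           = 4096;               // assumed hidden dimension
    std::printf("token %d, slot %d -> A offset %lld\n",
                token, slot, (long long)(token * K));
    return 0;
}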
@@ -1952,7 +1953,8 @@ struct GridwiseMoeGemm
             }
             gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
         });
-        const index_t expert_stride = __builtin_amdgcn_readfirstlane(problem.N * problem.K);
+        const index_t expert_stride =
+            __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
 
         // N0, K0, Blocksize*KPack
         const index_t n_block_data_idx_on_grid =
@@ -2025,7 +2027,7 @@ struct GridwiseMoeGemm
                           make_multi_index(n_block_data_idx_on_grid,
                                            get_warp_local_1d_id() % NWave,
                                            0,
-                                           KPack * (get_thread_local_1d_id() % warpSize)));
+                                           KPack / KGroup * (get_thread_local_1d_id() % warpSize)));
 
         // LDS allocation for A and B: be careful of alignment
         // Cast after lds
@@ -2042,24 +2044,76 @@ struct GridwiseMoeGemm
         static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
         auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
         auto c_thread_buf            = blockwise_gemm_pipeline.GetCThreadBuffer();
+        decltype(c_thread_buf) c_thread_buf_up;
+
+        StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
+                                  float,
+                                  c_thread_buf.num_of_v_,
+                                  c_thread_buf.s_per_v,
+                                  true>
+            c_thread_buf_fp32;
 
         const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
             (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
             KPerBlock);
 
-        blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
-                                                                         a_block_desc_ak0_m_ak1,
-                                                                         a_blockwise_copy,
-                                                                         a_grid_buf,
-                                                                         a_block_bufs,
-                                                                         a_block_slice_copy_step,
-                                                                         b_grid_desc_bpreshuffled,
-                                                                         b_blockwise_copy,
-                                                                         b_grid_buf,
-                                                                         b_block_bufs,
-                                                                         b_block_slice_copy_step,
-                                                                         c_thread_buf,
-                                                                         num_k_block_main_loop);
+        if constexpr(IsInputGemm)
+        {
+            const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize;
+            const auto b_grid_buf_up     = make_dynamic_buffer<AddressSpaceEnum::Global>(
+                p_b_grid_up + expert_id * expert_stride / BPackedSize,
+                b_grid_desc_bpreshuffled.GetElementSpaceSize());
+            auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2<
+                BDataType,
+                BDataType,
+                decltype(b_grid_desc_bpreshuffled),
+                decltype(b_block_desc_bk0_n_bk1),
+                Sequence<Number<NXdlPerWave>{}, I1, Number<KRepeat>{}, Number<BK1Value>{}>,
+                Sequence<1, 2, 0, 3>,
+                3,
+                BBlockTransferSrcScalarPerVector,
+                BThreadTransferSrcResetCoordinateAfterRun,
+                true>(b_grid_desc_bpreshuffled,
+                      make_multi_index(n_block_data_idx_on_grid,
+                                       get_warp_local_1d_id() % NWave,
+                                       0,
+                                       KPack / KGroup * (get_thread_local_1d_id() % warpSize)));
+            blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
+                a_grid_desc_ak0_m_ak1,
+                a_block_desc_ak0_m_ak1,
+                a_blockwise_copy,
+                a_grid_buf,
+                a_block_bufs,
+                a_block_slice_copy_step,
+                b_grid_desc_bpreshuffled,
+                b_blockwise_copy,
+                b_blockwise_copy_up,
+                b_grid_buf,
+                b_grid_buf_up,
+                b_block_bufs,
+                b_block_slice_copy_step,
+                c_thread_buf,
+                c_thread_buf_up,
+                num_k_block_main_loop);
+        }
+        else
+        {
+
+            blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
+                a_grid_desc_ak0_m_ak1,
+                a_block_desc_ak0_m_ak1,
+                a_blockwise_copy,
+                a_grid_buf,
+                a_block_bufs,
+                a_block_slice_copy_step,
+                b_grid_desc_bpreshuffled,
+                b_blockwise_copy,
+                b_grid_buf,
+                b_block_bufs,
+                b_block_slice_copy_step,
+                c_thread_buf,
+                num_k_block_main_loop);
+        }
 
         // shuffle C and write out
         {
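The IsInputGemm branch above implies a gate/up ("GU") weight layout: each expert stores N*K gate weights followed by N*K up weights, so expert_stride = N*K*2 and the up pointer sits half an expert_stride past the gate pointer (BPackedSize folds packed small types like pk_i4). A hedged sketch of that pointer math with assumed shapes:

#include <cstdio>

int main()
{
    const long long N = 4096, K = 2048;        // assumed per-expert shape
    const long long expert_stride = N * K * 2; // gate half + up half
    const int expert_id = 3, BPackedSize = 1;  // 1 for f8; 2 for packed i4/f4

    const long long gate_off = (long long)expert_id * expert_stride / BPackedSize;
    const long long up_off   = gate_off + expert_stride / 2 / BPackedSize;
    std::printf("expert %d: gate@%lld up@%lld\n", expert_id, gate_off, up_off);
    return 0;
}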
@@ -2087,6 +2141,185 @@ struct GridwiseMoeGemm
             constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
             constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
 
+            // mul scales
+            const float* p_sorted_weights_0 = p_ds_grid[I0];
+            const float* p_scale_b          = p_ds_grid[I1];
+
+            static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock);
+            static_assert(M4 == 4);
+            const index_t m1 = get_warp_local_1d_id() / NWave;
+            const index_t m3 = threadIdx.x % get_warp_size() / MPerXdl;
+
+            if(p_sorted_weights_0 != nullptr && p_scale_b != nullptr)
+            {
+                if constexpr(PerTokenQuant)
+                {
+                    constexpr index_t scale_stride = (IsInputGemm ? 2 : 1);
+                    p_scale_b += expert_id * problem.N * scale_stride + block_n_id * NPerBlock +
+                                 get_warp_local_1d_id() % NWave * NPerXdl + threadIdx.x % NPerXdl;
+                }
+                else
+                {
+                    p_scale_b += expert_id;
+                }
+
+                vector_type<int32_t, 4> scale_token_ids;
+                vector_type<float, 4> topk_weights;
+                static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
+                    const float scale_b = p_scale_b[n0 * NWave * NPerXdl * PerTokenQuant];
+                    static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave
+                        static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk
+                            const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 +
+                                                  m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4;
+                            if constexpr(PerTokenQuant)
+                            {
+                                scale_token_ids =
+                                    *c_style_pointer_cast<const vector_type<int32_t, M4>*>(
+                                        p_sorted_token_ids + m_pos);
+                            }
+                            if constexpr(MulRoutedWeight)
+                            {
+                                topk_weights = *c_style_pointer_cast<const vector_type<float, M4>*>(
+                                    p_ds_grid[I2] + m_pos);
+                            }
+                            static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size
+                                float scale_a = [&]() {
+                                    if constexpr(PerTokenQuant)
+                                    {
+                                        index_t fused_token = scale_token_ids.AsType<index_t>()[m4];
+                                        const index_t token_offset = fused_token & 0xffffff;
+                                        return token_offset < problem.NumTokens
+                                                   ? p_sorted_weights_0[token_offset]
+                                                   : 0.0;
+                                    }
+                                    else
+                                    {
+                                        return p_sorted_weights_0[0];
+                                    }
+                                }();
+                                constexpr index_t c_offset =
+                                    blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
+                                        make_tuple(m0, n0, m2 * M4 + m4));
+                                constexpr auto cidx = Number<c_offset>{};
+                                if constexpr(IsInputGemm) // gu fusion
+                                {
+                                    if constexpr(ActivationOperation == Activation::silu_and_mul)
+                                    {
+                                        const float scale_up =
+                                            p_scale_b[(n0 * NWave * NPerXdl + problem.N) *
+                                                      PerTokenQuant];
+                                        float gate = scale_a * scale_b * c_thread_buf[cidx];
+                                        float up   = scale_a * scale_up * c_thread_buf_up[cidx];
+                                        if constexpr(MulRoutedWeight)
+                                        {
+                                            gate = gate * topk_weights.AsType<float>()[m4];
+                                            up   = up * topk_weights.AsType<float>()[m4];
+                                        }
+                                        if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
+                                        {
+                                            gate *= 16;
+                                            up *= 16;
+                                        }
+                                        tensor_operation::element_wise::Silu{}(gate, gate);
+                                        c_thread_buf_fp32(cidx) = gate * up;
+                                    }
+                                    else if(ActivationOperation == Activation::gelu_and_mul)
+                                    {
+                                        const float scale_up =
+                                            p_scale_b[(n0 * NWave * NPerXdl + problem.N) *
+                                                      PerTokenQuant];
+                                        float gate = scale_a * scale_b * c_thread_buf[cidx];
+                                        float up   = scale_a * scale_up * c_thread_buf_up[cidx];
+                                        if constexpr(MulRoutedWeight)
+                                        {
+                                            gate = gate * topk_weights.AsType<float>()[m4];
+                                            up   = up * topk_weights.AsType<float>()[m4];
+                                        }
+                                        if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
+                                        {
+                                            gate *= 16;
+                                            up *= 16;
+                                        }
+                                        tensor_operation::element_wise::Gelu{}(gate, gate);
+                                        c_thread_buf_fp32(cidx) = gate * up;
+                                    }
+                                }
+                                else
+                                {
+                                    c_thread_buf_fp32(cidx) =
+                                        scale_a * scale_b * c_thread_buf[cidx];
+                                    if constexpr(MulRoutedWeight)
+                                    {
+                                        c_thread_buf_fp32(cidx) = c_thread_buf_fp32(cidx) *
+                                                                  topk_weights.AsType<float>()[m4];
+                                    }
+                                }
+                            });
+                        });
+                    });
+                });
+            }
+            else
+            {
+                vector_type<float, 4> topk_weights; // for gemm2 only
+                static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
+                    static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave
+                        static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk
+                            const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 +
+                                                  m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4;
+                            if constexpr(MulRoutedWeight)
+                            {
+                                topk_weights = *c_style_pointer_cast<const vector_type<float, M4>*>(
+                                    p_ds_grid[I2] + m_pos);
+                            }
+                            static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size
+                                constexpr index_t c_offset =
+                                    blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
+                                        make_tuple(m0, n0, m2 * M4 + m4));
+                                constexpr auto cidx = Number<c_offset>{};
+
+                                if constexpr(IsInputGemm) // gu fusion
+                                {
+                                    if constexpr(ActivationOperation == Activation::silu_and_mul)
+                                    {
+                                        float gate = c_thread_buf[cidx];
+                                        float up   = c_thread_buf_up[cidx];
+                                        if constexpr(MulRoutedWeight)
+                                        {
+                                            gate = gate * topk_weights.AsType<float>()[m4];
+                                            up   = up * topk_weights.AsType<float>()[m4];
+                                        }
+                                        tensor_operation::element_wise::Silu{}(gate, gate);
+                                        c_thread_buf_fp32(cidx) = gate * up;
+                                    }
+                                    else if(ActivationOperation == Activation::gelu_and_mul)
+                                    {
+                                        float gate = c_thread_buf[cidx];
+                                        float up   = c_thread_buf_up[cidx];
+                                        if constexpr(MulRoutedWeight)
+                                        {
+                                            gate = gate * topk_weights.AsType<float>()[m4];
+                                            up   = up * topk_weights.AsType<float>()[m4];
+                                        }
+                                        tensor_operation::element_wise::Gelu{}(gate, gate);
+                                        c_thread_buf_fp32(cidx) = gate * up;
+                                    }
+                                }
+                                else
+                                {
+                                    c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
+                                    if constexpr(MulRoutedWeight)
+                                    {
+                                        c_thread_buf_fp32(cidx) = topk_weights.AsType<float>()[m4] *
+                                                                  c_thread_buf_fp32[cidx];
+                                    }
+                                }
+                            });
+                        });
+                    });
+                });
+            }
 
             constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
                 GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
 
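The per-element epilogue math applied in the block above, extracted into a hedged host-side sketch (function and parameter names are illustrative, not from the kernel): scale_a is the per-token A scale, scale_b the per-channel B scale for the gate and up halves, and the routed top-k weight is optional.

#include <cmath>
#include <cstdio>

float silu(float x) { return x / (1.0f + std::exp(-x)); }

float epilogue(float acc_gate, float acc_up,
               float scale_a, float scale_b_gate, float scale_b_up,
               float topk_weight, bool mul_routed_weight)
{
    float gate = scale_a * scale_b_gate * acc_gate;
    float up   = scale_a * scale_b_up * acc_up;
    if(mul_routed_weight)
    {
        gate *= topk_weight;
        up *= topk_weight;
    }
    return silu(gate) * up; // silu_and_mul; gelu_and_mul swaps the activation
}

int main()
{
    std::printf("%f\n", epilogue(1.5f, 0.5f, 0.25f, 2.0f, 2.0f, 0.8f, true));
    return 0;
}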
@@ -2184,18 +2417,8 @@ struct GridwiseMoeGemm
 
         const auto ds_grid_buf = generate_tuple(
             [&](auto i) {
-                using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
-                const DDataType* ptr_ = p_ds_grid[i];
-                // hack logic here to support different kind of strides. todo fix it.
-                // ascale t, 1; bscale E, N, 1, move ptr to E
-                // if(i.value == 1)
-                // {
-                //     ptr_ +=
-                //         expert_id * (problem.StrideDs[1] ? problem.StrideDs[1] * problem.N :
-                //         1);
-                // }
                 return make_dynamic_buffer<AddressSpaceEnum::Global>(
-                    ptr_, ds_grid_desc_m_n[i].GetElementSpaceSize());
+                    p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
             },
             Number<NumDTensor>{});
 
@@ -2271,7 +2494,6 @@ struct GridwiseMoeGemm
 
         auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
-
         // space filling curve for threadwise C in VGPR
         constexpr auto sfc_c_vgpr =
             SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
                               Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
@@ -2310,7 +2532,7 @@ struct GridwiseMoeGemm
                     block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
                 static_for<0, EMRepeats, 1>{}([&](auto m0) {
                     const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
-                    index_t token_offset      = fused_token & 0xffffff;
+                    IndexType token_offset    = fused_token & 0xffffff;
                     if constexpr(IsInputGemm)
                     {
                         token_offset = token_offset * problem.TopK + (fused_token >> 24);
@@ -2323,7 +2545,7 @@ struct GridwiseMoeGemm
                 // each thread write its data from VGPR to LDS
                 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                               sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
-                                              c_thread_buf,
+                                              c_thread_buf_fp32,
                                               c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                               c_shuffle_block_buf);
 
File diff suppressed because it is too large
include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp (new file, 2652 lines): diff suppressed because it is too large
include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bns.hpp (new file, 2849 lines): diff suppressed because it is too large