From b12ae84a40150e192bac69e1ac677bba5a1cdf1c Mon Sep 17 00:00:00 2001 From: huaiguxu <145733371+huaiguxu@users.noreply.github.com> Date: Fri, 27 Jun 2025 10:24:34 +0800 Subject: [PATCH] Huaiguxu/moe fp8 pertoken scale fix (#2391) * fix pertoken_scale a_scale dimension * clang-format * Fix moe_gemm2_fp8 perTokenScale reference and example. [ROCm/composable_kernel commit: e1c5172fdb7eb4072943696f6a33937234843e3b] --- .../moe_gemm2_xdl_fp8.cpp | 14 +++++++++----- .../gpu/grid/gridwise_moe_gemm.hpp | 14 ++++++++++++-- .../cpu/reference_moe_gemm2.hpp | 7 +++++-- 3 files changed, 26 insertions(+), 9 deletions(-) diff --git a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8.cpp b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8.cpp index 3188ba142c..6a3986ea32 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8.cpp @@ -139,6 +139,7 @@ static constexpr ck::index_t EVec = 2; static constexpr ck::index_t D0Vec = 1; static constexpr ck::index_t D1Vec = 1; static constexpr ck::index_t D2Vec = 1; +static constexpr bool PerTokenQuant = true; static constexpr bool MulRoutedWeight = true; using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm // clang-format off @@ -169,7 +170,7 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic // MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| 2, 2, S<1, CShuffleMLane, 1, CShuffleNLane>, S, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, 0, false, false, MulRoutedWeight, false, int32_t, A0DataType>; + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, 0, false, false, MulRoutedWeight, PerTokenQuant, int32_t, A0DataType>; // kernel 2: 128->32x128x128 // < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, EDataType>; @@ -197,7 +198,7 @@ int main(int argc, char* argv[]) { // use default case } - else if(argc == 3) + else if(argc == 4) { // use default case do_verification = std::stoi(argv[1]); @@ -238,7 +239,8 @@ int main(int argc, char* argv[]) ck::index_t StrideB = K; ck::index_t StrideE = N; constexpr ck::index_t NumDTensor = DsDataType::Size(); - constexpr auto StrideDs = std::array{0, 0, 0}; + constexpr auto StrideDs = PerTokenQuant ? std::array{1, 1, 0} + : std::array{0, 0, 0}; ck::index_t KBatch = 1; @@ -279,8 +281,10 @@ int main(int argc, char* argv[]) Tensor a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1})); Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); - Tensor d0_t_n(HostTensorDescriptor({tokens, N}, {StrideDs[0], 0})); - Tensor d1_e_n(HostTensorDescriptor({experts, N}, {1, StrideDs[1]})); + Tensor d0_t_n( + HostTensorDescriptor({tokens, topk, N}, {StrideDs[0] * topk, StrideDs[0], 0})); + Tensor d1_e_n( + HostTensorDescriptor({experts, N}, {PerTokenQuant ? StrideDs[1] * N : 1, StrideDs[1]})); Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); Tensor e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1})); Tensor e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1})); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp index 92aab5af52..36f8fd7cc1 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp @@ -1473,7 +1473,12 @@ struct GridwiseMoeGemm index_t fused_token = scale_token_ids.AsType()[m4]; const index_t token_offset = fused_token & 0xffffff; return token_offset < problem.NumTokens - ? p_sorted_weights_0[token_offset] + ? p_sorted_weights_0[IsInputGemm + ? token_offset + : token_offset * + problem.TopK + + (fused_token >> + 24)] : 0.0; } else @@ -2190,7 +2195,12 @@ struct GridwiseMoeGemm index_t fused_token = scale_token_ids.AsType()[m4]; const index_t token_offset = fused_token & 0xffffff; return token_offset < problem.NumTokens - ? p_sorted_weights_0[token_offset] + ? p_sorted_weights_0[IsInputGemm + ? token_offset + : token_offset * + problem.TopK + + (fused_token >> + 24)] : 0.0; } else diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm2.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm2.hpp index 583d704040..58e4adfdfa 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm2.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm2.hpp @@ -6,6 +6,7 @@ #include #include #include +#include #include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/device/device_base.hpp" @@ -85,6 +86,7 @@ struct ReferenceMoeGemm2 : public device::BaseOperator float Run(const Argument& arg) { + std::vector n_locks(arg.c_t_n_.mDesc.GetLengths()[1]); arg.c_t_n_.SetZero(); auto f_mk_kn_mn = [&](auto m, auto n) { const int K = arg.a_t_k_k_.mDesc.GetLengths()[2]; @@ -142,8 +144,8 @@ struct ReferenceMoeGemm2 : public device::BaseOperator ck::type_convert(v_a) * ck::type_convert(v_b); } CDataType v_c{0}; - D0DataType v_d0 = arg.d0_(m, n); // a - D0DataType v_d1 = arg.d1_(e, n); // b + D0DataType v_d0 = arg.d0_(t, topk_id); // a + D0DataType v_d1 = arg.d1_(e, n); // b if constexpr(MulRoutedWeight) { arg.c_element_op_(v_c, v_acc, v_d0, v_d1, v_topk_w); @@ -152,6 +154,7 @@ struct ReferenceMoeGemm2 : public device::BaseOperator { arg.c_element_op_(v_c, v_acc, v_d0, v_d1, 1.f); } + std::lock_guard lock(n_locks[n]); arg.c_t_n_(t, n) += v_c; } };